diff --git a/src/alerts/alerts_utils.rs b/src/alerts/alerts_utils.rs index eeabb29bd..a3b9d5a45 100644 --- a/src/alerts/alerts_utils.rs +++ b/src/alerts/alerts_utils.rs @@ -29,10 +29,15 @@ use datafusion::{ logical_expr::{BinaryExpr, Literal, Operator}, prelude::{col, lit, DataFrame, Expr}, }; -use tracing::trace; +use tokio::task::JoinSet; +use tracing::{trace, warn}; use crate::{ - alerts::LogicalOperator, parseable::PARSEABLE, query::QUERY_SESSION, utils::time::TimeRange, + alerts::LogicalOperator, + handlers::http::query::update_schema_when_distributed, + parseable::PARSEABLE, + query::{resolve_stream_names, QUERY_SESSION}, + utils::time::TimeRange, }; use super::{ @@ -71,11 +76,37 @@ async fn prepare_query(alert: &AlertConfig) -> Result plan, + Err(_) => { + let mut join_set = JoinSet::new(); + for stream_name in streams { + let stream_name = stream_name.clone(); + join_set.spawn(async move { + let result = PARSEABLE + .create_stream_and_schema_from_storage(&stream_name) + .await; + + if let Err(e) = &result { + warn!("Failed to create stream '{}': {}", stream_name, e); + } + + (stream_name, result) + }); + } + + while let Some(result) = join_set.join_next().await { + if let Err(join_error) = result { + warn!("Task join error: {}", join_error); + } + } + session_state.create_logical_plan(&select_query).await? + } + }; Ok(crate::query::Query { raw_logical_plan, time_range, @@ -87,11 +118,18 @@ async fn execute_base_query( query: &crate::query::Query, original_query: &str, ) -> Result { - let stream_name = query.first_table_name().ok_or_else(|| { + let streams = resolve_stream_names(original_query)?; + let stream_name = streams.first().ok_or_else(|| { AlertError::CustomError(format!("Table name not found in query- {original_query}")) })?; - - let time_partition = PARSEABLE.get_stream(&stream_name)?.get_time_partition(); + update_schema_when_distributed(&streams) + .await + .map_err(|err| { + AlertError::CustomError(format!( + "Failed to update schema for distributed streams: {err}" + )) + })?; + let time_partition = PARSEABLE.get_stream(stream_name)?.get_time_partition(); query .get_dataframe(time_partition.as_ref()) .await diff --git a/src/alerts/mod.rs b/src/alerts/mod.rs index 50ed91276..2fe458a2e 100644 --- a/src/alerts/mod.rs +++ b/src/alerts/mod.rs @@ -19,6 +19,7 @@ use actix_web::http::header::ContentType; use async_trait::async_trait; use chrono::Utc; +use datafusion::sql::sqlparser::parser::ParserError; use derive_more::derive::FromStr; use derive_more::FromStrError; use http::StatusCode; @@ -860,6 +861,8 @@ pub enum AlertError { InvalidTargetModification(String), #[error("Can't delete a Target which is being used")] TargetInUse, + #[error("{0}")] + ParserError(#[from] ParserError), } impl actix_web::ResponseError for AlertError { @@ -880,6 +883,7 @@ impl actix_web::ResponseError for AlertError { Self::InvalidTargetID(_) => StatusCode::BAD_REQUEST, Self::InvalidTargetModification(_) => StatusCode::BAD_REQUEST, Self::TargetInUse => StatusCode::CONFLICT, + Self::ParserError(_) => StatusCode::BAD_REQUEST, } } diff --git a/src/correlation.rs b/src/correlation.rs index c5f4eb2d8..f7bb65eec 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -87,7 +87,7 @@ impl Correlations { .iter() .map(|t| t.table_name.clone()) .collect_vec(); - if user_auth_for_datasets(&permissions, tables).is_ok() { + if user_auth_for_datasets(&permissions, tables).await.is_ok() { user_correlations.push(correlation.clone()); } } @@ -281,7 +281,7 @@ impl CorrelationConfig { .map(|t| t.table_name.clone()) .collect_vec(); - user_auth_for_datasets(&permissions, tables)?; + user_auth_for_datasets(&permissions, tables).await?; // to validate table config, we need to check whether the mentioned fields // are present in the table or not diff --git a/src/event/mod.rs b/src/event/mod.rs index c60f0d057..64c942de9 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -123,7 +123,7 @@ pub fn commit_schema(stream_name: &str, schema: Arc) -> Result<(), Stagi let map = &mut stream_metadata .get_mut(stream_name) - .expect("map has entry for this stream name") + .ok_or_else(|| StagingError::NotFound(stream_name.to_string()))? .metadata .write() .expect(LOCK_EXPECT) diff --git a/src/handlers/airplane.rs b/src/handlers/airplane.rs index 8831346eb..4ed2f3950 100644 --- a/src/handlers/airplane.rs +++ b/src/handlers/airplane.rs @@ -20,8 +20,6 @@ use arrow_array::RecordBatch; use arrow_flight::flight_service_server::FlightServiceServer; use arrow_flight::PollInfo; use arrow_schema::ArrowError; - -use datafusion::common::tree_node::TreeNode; use serde_json::json; use std::net::SocketAddr; use std::time::Instant; @@ -35,11 +33,11 @@ use tonic_web::GrpcWebLayer; use crate::handlers::http::cluster::get_node_info; use crate::handlers::http::modal::{NodeMetadata, NodeType}; -use crate::handlers::http::query::{into_query, update_schema_when_distributed}; +use crate::handlers::http::query::into_query; use crate::handlers::livetail::cross_origin_config; use crate::metrics::QUERY_EXECUTE_TIME; use crate::parseable::PARSEABLE; -use crate::query::{execute, TableScanVisitor, QUERY_SESSION}; +use crate::query::{execute, resolve_stream_names, QUERY_SESSION}; use crate::utils::arrow::flight::{ append_temporary_events, get_query_from_ticket, into_flight_data, run_do_get_rpc, send_to_ingester, @@ -131,40 +129,26 @@ impl FlightService for AirServiceImpl { let ticket = get_query_from_ticket(&req).map_err(|e| Status::invalid_argument(e.to_string()))?; - + let streams = resolve_stream_names(&ticket.query).map_err(|e| { + error!("Failed to extract table names from SQL: {}", e); + Status::invalid_argument("Invalid SQL query syntax") + })?; info!("query requested to airplane: {:?}", ticket); // get the query session_state let session_state = QUERY_SESSION.state(); - // get the logical plan and extract the table name - let raw_logical_plan = session_state - .create_logical_plan(&ticket.query) - .await - .map_err(|err| { - error!("Datafusion Error: Failed to create logical plan: {}", err); - Status::internal("Failed to create logical plan") - })?; - let time_range = TimeRange::parse_human_time(&ticket.start_time, &ticket.end_time) .map_err(|e| Status::internal(e.to_string()))?; // create a visitor to extract the table name - let mut visitor = TableScanVisitor::default(); - let _ = raw_logical_plan.visit(&mut visitor); - - let streams = visitor.into_inner(); let stream_name = streams .first() .ok_or_else(|| Status::aborted("Malformed SQL Provided, Table Name Not Found"))? .to_owned(); - update_schema_when_distributed(&streams) - .await - .map_err(|err| Status::internal(err.to_string()))?; - // map payload to query - let query = into_query(&ticket, &session_state, time_range) + let query = into_query(&ticket, &session_state, time_range, &streams) .await .map_err(|_| Status::internal("Failed to parse query"))?; @@ -214,9 +198,11 @@ impl FlightService for AirServiceImpl { let permissions = Users.get_permissions(&key); - user_auth_for_datasets(&permissions, &streams).map_err(|_| { - Status::permission_denied("User Does not have permission to access this") - })?; + user_auth_for_datasets(&permissions, &streams) + .await + .map_err(|_| { + Status::permission_denied("User Does not have permission to access this") + })?; let time = Instant::now(); let (records, _) = execute(query, &stream_name, false) diff --git a/src/handlers/http/correlation.rs b/src/handlers/http/correlation.rs index facd8a64c..f9c77f9da 100644 --- a/src/handlers/http/correlation.rs +++ b/src/handlers/http/correlation.rs @@ -54,7 +54,7 @@ pub async fn get( .map(|t| t.table_name.clone()) .collect_vec(); - user_auth_for_datasets(&permissions, tables)?; + user_auth_for_datasets(&permissions, tables).await?; Ok(web::Json(correlation)) } diff --git a/src/handlers/http/query.rs b/src/handlers/http/query.rs index 85634d031..48ac411ec 100644 --- a/src/handlers/http/query.rs +++ b/src/handlers/http/query.rs @@ -18,6 +18,7 @@ use crate::event::error::EventError; use crate::handlers::http::fetch_schema; +use crate::option::Mode; use crate::utils::arrow::record_batches_to_json; use actix_web::http::header::ContentType; use actix_web::web::{self, Json}; @@ -25,9 +26,9 @@ use actix_web::{Either, FromRequest, HttpRequest, HttpResponse, Responder}; use arrow_array::RecordBatch; use bytes::Bytes; use chrono::{DateTime, Utc}; -use datafusion::common::tree_node::TreeNode; use datafusion::error::DataFusionError; use datafusion::execution::context::SessionState; +use datafusion::sql::sqlparser::parser::ParserError; use futures::stream::once; use futures::{future, Stream, StreamExt}; use futures_util::Future; @@ -44,11 +45,10 @@ use tracing::{error, warn}; use crate::event::commit_schema; use crate::metrics::QUERY_EXECUTE_TIME; -use crate::option::Mode; use crate::parseable::{StreamNotFound, PARSEABLE}; use crate::query::error::ExecuteError; use crate::query::{execute, CountsRequest, Query as LogicalQuery}; -use crate::query::{TableScanVisitor, QUERY_SESSION}; +use crate::query::{resolve_stream_names, QUERY_SESSION}; use crate::rbac::Users; use crate::response::QueryResponse; use crate::storage::ObjectStorageError; @@ -81,33 +81,23 @@ pub async fn get_records_and_fields( query_request: &Query, req: &HttpRequest, ) -> Result<(Option>, Option>), QueryError> { + let tables = resolve_stream_names(&query_request.query)?; let session_state = QUERY_SESSION.state(); - // get the logical plan and extract the table name - let raw_logical_plan = session_state - .create_logical_plan(&query_request.query) - .await?; - let time_range = TimeRange::parse_human_time(&query_request.start_time, &query_request.end_time)?; - // create a visitor to extract the table name - let mut visitor = TableScanVisitor::default(); - let _ = raw_logical_plan.visit(&mut visitor); - - let tables = visitor.into_inner(); - update_schema_when_distributed(&tables).await?; - let query: LogicalQuery = into_query(query_request, &session_state, time_range).await?; + let query: LogicalQuery = + into_query(query_request, &session_state, time_range, &tables).await?; let creds = extract_session_key_from_req(req)?; let permissions = Users.get_permissions(&creds); - let table_name = query - .first_table_name() + let table_name = tables + .first() .ok_or_else(|| QueryError::MalformedQuery("No table name found in query"))?; - - user_auth_for_datasets(&permissions, &tables)?; - - let (records, fields) = execute(query, &table_name, false).await?; + user_auth_for_datasets(&permissions, &tables).await?; + update_schema_when_distributed(&tables).await?; + let (records, fields) = execute(query, table_name, false).await?; let records = match records { Either::Left(vec_rb) => vec_rb, @@ -121,54 +111,37 @@ pub async fn get_records_and_fields( pub async fn query(req: HttpRequest, query_request: Query) -> Result { let session_state = QUERY_SESSION.state(); - let raw_logical_plan = match session_state - .create_logical_plan(&query_request.query) - .await - { - Ok(raw_logical_plan) => raw_logical_plan, - Err(_) => { - create_streams_for_querier().await?; - session_state - .create_logical_plan(&query_request.query) - .await? - } - }; let time_range = TimeRange::parse_human_time(&query_request.start_time, &query_request.end_time)?; - - let mut visitor = TableScanVisitor::default(); - let _ = raw_logical_plan.visit(&mut visitor); - let tables = visitor.into_inner(); + let tables = resolve_stream_names(&query_request.query)?; update_schema_when_distributed(&tables).await?; - let query: LogicalQuery = into_query(&query_request, &session_state, time_range).await?; - + let query: LogicalQuery = + into_query(&query_request, &session_state, time_range, &tables).await?; let creds = extract_session_key_from_req(&req)?; let permissions = Users.get_permissions(&creds); - let table_name = query - .first_table_name() + let table_name = tables + .first() .ok_or_else(|| QueryError::MalformedQuery("No table name found in query"))?; - - user_auth_for_datasets(&permissions, &tables)?; - + user_auth_for_datasets(&permissions, &tables).await?; let time = Instant::now(); // if the query is `select count(*) from ` // we use the `get_bin_density` method to get the count of records in the dataset // instead of executing the query using datafusion if let Some(column_name) = query.is_logical_plan_count_without_filters() { - return handle_count_query(&query_request, &table_name, column_name, time).await; + return handle_count_query(&query_request, table_name, column_name, time).await; } // if the query request has streaming = false (default) // we use datafusion's `execute` method to get the records if !query_request.streaming { - return handle_non_streaming_query(query, &table_name, &query_request, time).await; + return handle_non_streaming_query(query, table_name, &query_request, time).await; } // if the query request has streaming = true // we use datafusion's `execute_stream` method to get the records - handle_streaming_query(query, &table_name, &query_request, time).await + handle_streaming_query(query, table_name, &query_request, time).await } /// Handles count queries (e.g., `SELECT COUNT(*) FROM `) @@ -372,7 +345,7 @@ pub async fn get_counts( let body = counts_request.into_inner(); // does user have access to table? - user_auth_for_datasets(&permissions, &[body.stream.clone()])?; + user_auth_for_datasets(&permissions, &[body.stream.clone()]).await?; // if the user has given a sql query (counts call with filters applied), then use this flow // this could include filters or group by @@ -427,7 +400,6 @@ pub async fn update_schema_when_distributed(tables: &Vec) -> Result<(), } } } - Ok(()) } @@ -520,6 +492,7 @@ pub async fn into_query( query: &Query, session_state: &SessionState, time_range: TimeRange, + tables: &Vec, ) -> Result { if query.query.is_empty() { return Err(QueryError::EmptyQuery); @@ -532,9 +505,36 @@ pub async fn into_query( if query.end_time.is_empty() { return Err(QueryError::EmptyEndTime); } + let raw_logical_plan = match session_state.create_logical_plan(&query.query).await { + Ok(plan) => plan, + Err(_) => { + let mut join_set = JoinSet::new(); + for stream_name in tables { + let stream_name = stream_name.clone(); + join_set.spawn(async move { + let result = PARSEABLE + .create_stream_and_schema_from_storage(&stream_name) + .await; + + if let Err(e) = &result { + warn!("Failed to create stream '{}': {}", stream_name, e); + } + + (stream_name, result) + }); + } + + while let Some(result) = join_set.join_next().await { + if let Err(join_error) = result { + warn!("Task join error: {}", join_error); + } + } + session_state.create_logical_plan(&query.query).await? + } + }; Ok(crate::query::Query { - raw_logical_plan: session_state.create_logical_plan(&query.query).await?, + raw_logical_plan, time_range, filter_tag: query.filter_tags.clone(), }) @@ -618,6 +618,8 @@ Description: {0}"# CustomError(String), #[error("No available queriers found")] NoAvailableQuerier, + #[error("{0}")] + ParserError(#[from] ParserError), } impl actix_web::ResponseError for QueryError { diff --git a/src/parseable/staging/mod.rs b/src/parseable/staging/mod.rs index 256133841..60150b9d3 100644 --- a/src/parseable/staging/mod.rs +++ b/src/parseable/staging/mod.rs @@ -30,6 +30,8 @@ pub enum StagingError { ObjectStorage(#[from] std::io::Error), #[error("Could not generate parquet file")] Create, + #[error("Could not find stream {0}")] + NotFound(String), // #[error("Metadata Error: {0}")] // Metadata(#[from] MetadataError), } diff --git a/src/query/mod.rs b/src/query/mod.rs index b56e15328..e624c213c 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -24,7 +24,8 @@ use actix_web::Either; use chrono::NaiveDateTime; use chrono::{DateTime, Duration, Utc}; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion::catalog::resolve_table_references; +use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::error::DataFusionError; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder}; @@ -33,6 +34,8 @@ use datafusion::logical_expr::{ Aggregate, Explain, Filter, LogicalPlan, PlanType, Projection, ToStringifiedPlan, }; use datafusion::prelude::*; +use datafusion::sql::parser::DFParser; +use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; use itertools::Itertools; use once_cell::sync::Lazy; use relative_path::RelativePathBuf; @@ -259,12 +262,6 @@ impl Query { } } - pub fn first_table_name(&self) -> Option { - let mut visitor = TableScanVisitor::default(); - let _ = self.raw_logical_plan.visit(&mut visitor); - visitor.into_inner().pop() - } - /// Evaluates to Some("count(*)") | Some("column_name") if the logical plan is a Projection: SELECT COUNT(*) | SELECT COUNT(*) as column_name pub fn is_logical_plan_count_without_filters(&self) -> Option<&String> { // Check if the raw logical plan is a Projection: SELECT @@ -488,29 +485,18 @@ pub struct CountsResponse { pub records: Vec, } -#[derive(Debug, Default)] -pub struct TableScanVisitor { - tables: Vec, -} - -impl TableScanVisitor { - pub fn into_inner(self) -> Vec { - self.tables - } -} - -impl TreeNodeVisitor<'_> for TableScanVisitor { - type Node = LogicalPlan; - - fn f_down(&mut self, node: &Self::Node) -> Result { - match node { - LogicalPlan::TableScan(table) => { - self.tables.push(table.table_name.table().to_string()); - Ok(TreeNodeRecursion::Jump) - } - _ => Ok(TreeNodeRecursion::Continue), - } +pub fn resolve_stream_names(sql: &str) -> Result, anyhow::Error> { + let normalized_sql = sql.replace('`', "\""); + let dialect = &PostgreSqlDialect {}; + let statement = DFParser::parse_sql_with_dialect(&normalized_sql, dialect)? + .pop_back() + .ok_or(anyhow::anyhow!("Failed to parse sql"))?; + let (table_refs, _) = resolve_table_references(&statement, true)?; + let mut tables = Vec::new(); + for table in table_refs { + tables.push(table.table().to_string()); } + Ok(tables) } pub async fn get_manifest_list( diff --git a/src/users/filters.rs b/src/users/filters.rs index 42d34bb9d..0106c47f0 100644 --- a/src/users/filters.rs +++ b/src/users/filters.rs @@ -193,7 +193,10 @@ impl Filters { } } else if *filter_type == FilterType::Search || *filter_type == FilterType::Filter { let dataset_name = &f.stream_name; - if user_auth_for_datasets(&permissions, &[dataset_name.to_string()]).is_ok() { + if user_auth_for_datasets(&permissions, &[dataset_name.to_string()]) + .await + .is_ok() + { filters.push(f.clone()) } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ee7583ba7..93e3e5980 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -27,14 +27,13 @@ pub mod update; use crate::handlers::http::rbac::RBACError; use crate::parseable::PARSEABLE; -use crate::query::{TableScanVisitor, QUERY_SESSION}; +use crate::query::resolve_stream_names; use crate::rbac::map::SessionKey; use crate::rbac::role::{Action, ParseableResourceType, Permission}; use crate::rbac::Users; use actix::extract_session_key_from_req; use actix_web::HttpRequest; use chrono::{NaiveDate, NaiveDateTime, NaiveTime, Utc}; -use datafusion::common::tree_node::TreeNode; use regex::Regex; use sha2::{Digest, Sha256}; @@ -78,28 +77,18 @@ pub fn get_hash(key: &str) -> String { result } -async fn get_tables_from_query(query: &str) -> Result { - let session_state = QUERY_SESSION.state(); - let raw_logical_plan = session_state - .create_logical_plan(query) - .await - .map_err(|e| actix_web::error::ErrorInternalServerError(format!("Query error: {e}")))?; - - let mut visitor = TableScanVisitor::default(); - let _ = raw_logical_plan.visit(&mut visitor); - Ok(visitor) -} - pub async fn user_auth_for_query( session_key: &SessionKey, query: &str, ) -> Result<(), actix_web::error::Error> { - let tables = get_tables_from_query(query).await?.into_inner(); + let tables = resolve_stream_names(query).map_err(|e| { + actix_web::error::ErrorBadRequest(format!("Failed to extract table names: {e}")) + })?; let permissions = Users.get_permissions(session_key); - user_auth_for_datasets(&permissions, &tables) + user_auth_for_datasets(&permissions, &tables).await } -pub fn user_auth_for_datasets( +pub async fn user_auth_for_datasets( permissions: &[Permission], tables: &[String], ) -> Result<(), actix_web::error::Error> { @@ -115,6 +104,11 @@ pub fn user_auth_for_datasets( break; } Permission::Resource(Action::Query, ParseableResourceType::Stream(stream)) => { + if !PARSEABLE.check_or_load_stream(stream).await { + return Err(actix_web::error::ErrorUnauthorized(format!( + "Stream not found: {table_name}" + ))); + } let is_internal = PARSEABLE.get_stream(table_name).is_ok_and(|stream| { stream .get_stream_type() @@ -154,22 +148,3 @@ pub fn user_auth_for_datasets( Ok(()) } - -/// A function to extract table names from a SQL string -pub async fn extract_tables(sql: &str) -> Option> { - let session_state = QUERY_SESSION.state(); - - // get the logical plan and extract the table name - let raw_logical_plan = match session_state.create_logical_plan(sql).await { - Ok(plan) => plan, - Err(_) => return None, - }; - - // create a visitor to extract the table name - let mut visitor = TableScanVisitor::default(); - let _ = raw_logical_plan.visit(&mut visitor); - - let tables = visitor.into_inner(); - - Some(tables) -}