diff --git a/lib/executor/src/context.rs b/lib/executor/src/context.rs index 7982d3034..58b6e1365 100644 --- a/lib/executor/src/context.rs +++ b/lib/executor/src/context.rs @@ -4,7 +4,11 @@ use hive_router_query_planner::planner::plan_nodes::{FetchNode, FetchRewrite, Qu use crate::{ headers::plan::ResponseHeaderAggregator, - response::{graphql_error::GraphQLError, storage::ResponsesStorage, value::Value}, + response::{ + graphql_error::{GraphQLError, GraphQLErrorPath}, + storage::ResponsesStorage, + value::Value, + }, }; pub struct ExecutionContext<'a> { @@ -38,10 +42,20 @@ impl<'a> ExecutionContext<'a> { } } - pub fn handle_errors(&mut self, errors: Option>) { - if let Some(errors) = errors { - for error in errors { - self.errors.push(error); + pub fn handle_errors( + &mut self, + errors: Option>, + entity_index_error_map: Option>>, + ) { + if let Some(response_errors) = errors { + for response_error in response_errors { + if let Some(entity_index_error_map) = &entity_index_error_map { + let normalized_errors = + response_error.normalize_entity_error(entity_index_error_map); + self.errors.extend(normalized_errors); + } else { + self.errors.push(response_error); + } } } } diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 1f1a9ba40..e78d2af91 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -36,7 +36,9 @@ use crate::{ response::project_by_operation, }, response::{ - graphql_error::GraphQLError, merge::deep_merge, subgraph_response::SubgraphResponse, + graphql_error::{GraphQLError, GraphQLErrorPath}, + merge::deep_merge, + subgraph_response::SubgraphResponse, value::Value, }, utils::{ @@ -420,7 +422,7 @@ impl<'exec> Executor<'exec> { ctx: &mut ExecutionContext<'exec>, response_bytes: Bytes, fetch_node_id: i64, - ) -> Option<(Value<'exec>, Option<&'exec Vec>)> { + ) -> Option<(SubgraphResponse<'exec>, Option<&'exec Vec>)> { let idx = ctx.response_storage.add_response(response_bytes); // SAFETY: The `bytes` are transmuted to the lifetime `'a` of the `ExecutionContext`. // This is safe because the `response_storage` is part of the `ExecutionContext` (`ctx`) @@ -452,9 +454,7 @@ impl<'exec> Executor<'exec> { } }; - ctx.handle_errors(response.errors); - - Some((response.data, output_rewrites)) + Some((response, output_rewrites)) } fn process_job_result( @@ -472,16 +472,18 @@ impl<'exec> Executor<'exec> { &mut ctx.response_headers_aggregator, )?; - if let Some((mut data, output_rewrites)) = + if let Some((mut response, output_rewrites)) = self.process_subgraph_response(ctx, job.response.body, job.fetch_node_id) { + ctx.handle_errors(response.errors, None); if let Some(output_rewrites) = output_rewrites { for output_rewrite in output_rewrites { - output_rewrite.rewrite(&self.schema_metadata.possible_types, &mut data); + output_rewrite + .rewrite(&self.schema_metadata.possible_types, &mut response.data); } } - deep_merge(&mut ctx.final_response, data); + deep_merge(&mut ctx.final_response, response.data); } } ExecutionJob::FlattenFetch(job) => { @@ -493,10 +495,10 @@ impl<'exec> Executor<'exec> { &mut ctx.response_headers_aggregator, )?; - if let Some((mut data, output_rewrites)) = + if let Some((mut response, output_rewrites)) = self.process_subgraph_response(ctx, job.response.body, job.fetch_node_id) { - if let Some(mut entities) = data.take_entities() { + if let Some(mut entities) = response.data.take_entities() { if let Some(output_rewrites) = output_rewrites { for output_rewrite in output_rewrites { for entity in &mut entities { @@ -508,15 +510,33 @@ impl<'exec> Executor<'exec> { let mut index = 0; let normalized_path = job.flatten_node_path.as_slice(); + // If there is an error in the response, then collect the paths for normalizing the error + let initial_error_path = response + .errors + .as_ref() + .map(|_| GraphQLErrorPath::with_capacity(normalized_path.len() + 2)); + let mut entity_index_error_map = response + .errors + .as_ref() + .map(|_| HashMap::with_capacity(entities.len())); traverse_and_callback_mut( &mut ctx.final_response, normalized_path, self.schema_metadata, - &mut |target| { + initial_error_path, + &mut |target, error_path| { let hash = job.representation_hashes[index]; if let Some(entity_index) = job.representation_hash_to_index.get(&hash) { + if let (Some(error_path), Some(entity_index_error_map)) = + (error_path, entity_index_error_map.as_mut()) + { + let error_paths = entity_index_error_map + .entry(entity_index) + .or_insert_with(Vec::new); + error_paths.push(error_path); + } if let Some(entity) = entities.get(*entity_index) { // SAFETY: `new_val` is a clone of an entity that lives for `'a`. // The transmute is to satisfy the compiler, but the lifetime @@ -529,6 +549,7 @@ impl<'exec> Executor<'exec> { index += 1; }, ); + ctx.handle_errors(response.errors, entity_index_error_map); } } } @@ -714,6 +735,8 @@ fn select_fetch_variables<'a>( #[cfg(test)] mod tests { + use crate::{context::ExecutionContext, response::graphql_error::GraphQLErrorPath}; + use super::select_fetch_variables; use sonic_rs::Value; use std::collections::{BTreeSet, HashMap}; @@ -768,4 +791,65 @@ mod tests { assert!(selected.is_none()); } + #[test] + /** + * We have the same entity in two different paths ["a", 0] and ["b", 1], + * and the subgraph response has an error for this entity. + * So we should duplicate the error for both paths. + */ + fn normalize_entity_errors_correctly() { + use crate::response::graphql_error::{GraphQLError, GraphQLErrorPathSegment}; + use std::collections::HashMap; + let mut ctx = ExecutionContext::default(); + let mut entity_index_error_map: HashMap<&usize, Vec> = HashMap::new(); + entity_index_error_map.insert( + &0, + vec![ + GraphQLErrorPath { + segments: vec![ + GraphQLErrorPathSegment::String("a".to_string()), + GraphQLErrorPathSegment::Index(0), + ], + }, + GraphQLErrorPath { + segments: vec![ + GraphQLErrorPathSegment::String("b".to_string()), + GraphQLErrorPathSegment::Index(1), + ], + }, + ], + ); + let response_errors = vec![GraphQLError { + message: "Error 1".to_string(), + locations: None, + path: Some(GraphQLErrorPath { + segments: vec![ + GraphQLErrorPathSegment::String("_entities".to_string()), + GraphQLErrorPathSegment::Index(0), + GraphQLErrorPathSegment::String("field1".to_string()), + ], + }), + extensions: None, + }]; + ctx.handle_errors(Some(response_errors), Some(entity_index_error_map)); + assert_eq!(ctx.errors.len(), 2); + assert_eq!(ctx.errors[0].message, "Error 1"); + assert_eq!( + ctx.errors[0].path.as_ref().unwrap().segments, + vec![ + GraphQLErrorPathSegment::String("a".to_string()), + GraphQLErrorPathSegment::Index(0), + GraphQLErrorPathSegment::String("field1".to_string()) + ] + ); + assert_eq!(ctx.errors[1].message, "Error 1"); + assert_eq!( + ctx.errors[1].path.as_ref().unwrap().segments, + vec![ + GraphQLErrorPathSegment::String("b".to_string()), + GraphQLErrorPathSegment::Index(1), + GraphQLErrorPathSegment::String("field1".to_string()) + ] + ); + } } diff --git a/lib/executor/src/response/graphql_error.rs b/lib/executor/src/response/graphql_error.rs index 88b9c6e8e..fd254dbaf 100644 --- a/lib/executor/src/response/graphql_error.rs +++ b/lib/executor/src/response/graphql_error.rs @@ -2,7 +2,7 @@ use graphql_parser::Pos; use graphql_tools::validation::utils::ValidationError; use serde::{de, Deserialize, Deserializer, Serialize}; use sonic_rs::Value; -use std::fmt; +use std::{collections::HashMap, fmt}; #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] @@ -11,7 +11,7 @@ pub struct GraphQLError { #[serde(default, skip_serializing_if = "Option::is_none")] pub locations: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] - pub path: Option>, + pub path: Option, pub extensions: Option, } @@ -46,13 +46,43 @@ impl From<&Pos> for GraphQLErrorLocation { } } +impl GraphQLError { + pub fn entity_index_and_path<'a>(&'a self) -> Option> { + self.path.as_ref().and_then(|p| p.entity_index_and_path()) + } + + pub fn normalize_entity_error( + self, + entity_index_error_map: &HashMap<&usize, Vec>, + ) -> Vec { + if let Some(entity_index_and_path) = &self.entity_index_and_path() { + if let Some(entity_error_paths) = + entity_index_error_map.get(&entity_index_and_path.entity_index) + { + return entity_error_paths + .iter() + .map(|error_path| { + let mut new_error_path = error_path.clone(); + new_error_path.extend_from_slice(entity_index_and_path.rest_of_path); + GraphQLError { + path: Some(new_error_path), + ..self.clone() + } + }) + .collect(); + } + } + vec![self] + } +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct GraphQLErrorLocation { pub line: usize, pub column: usize, } -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Debug, Serialize, PartialEq)] pub enum GraphQLErrorPathSegment { String(String), Index(usize), @@ -110,3 +140,53 @@ impl<'de> Deserialize<'de> for GraphQLErrorPathSegment { deserializer.deserialize_any(PathSegmentVisitor) } } + +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct GraphQLErrorPath { + #[serde(flatten)] + pub segments: Vec, +} + +pub struct EntityIndexAndPath<'a> { + pub entity_index: usize, + pub rest_of_path: &'a [GraphQLErrorPathSegment], +} + +impl GraphQLErrorPath { + pub fn with_capacity(capacity: usize) -> Self { + GraphQLErrorPath { + segments: Vec::with_capacity(capacity), + } + } + pub fn concat(&self, segment: GraphQLErrorPathSegment) -> Self { + let mut new_path = self.segments.clone(); + new_path.push(segment); + GraphQLErrorPath { segments: new_path } + } + + pub fn concat_index(&self, index: usize) -> Self { + self.concat(GraphQLErrorPathSegment::Index(index)) + } + + pub fn concat_str(&self, field: String) -> Self { + self.concat(GraphQLErrorPathSegment::String(field)) + } + + pub fn extend_from_slice(&mut self, other: &[GraphQLErrorPathSegment]) { + self.segments.extend_from_slice(other); + } + + pub fn entity_index_and_path<'a>(&'a self) -> Option> { + match &self.segments.as_slice() { + [GraphQLErrorPathSegment::String(maybe_entities), GraphQLErrorPathSegment::Index(entity_index), rest_of_path @ ..] + if maybe_entities == "_entities" => + { + Some(EntityIndexAndPath { + entity_index: *entity_index, + rest_of_path, + }) + } + _ => None, + } + } +} diff --git a/lib/executor/src/utils/traverse.rs b/lib/executor/src/utils/traverse.rs index 4cb22e98e..d9fcb94b4 100644 --- a/lib/executor/src/utils/traverse.rs +++ b/lib/executor/src/utils/traverse.rs @@ -1,7 +1,8 @@ use hive_router_query_planner::planner::plan_nodes::FlattenNodePathSegment; use crate::{ - introspection::schema::SchemaMetadata, response::value::Value, + introspection::schema::SchemaMetadata, + response::{graphql_error::GraphQLErrorPath, value::Value}, utils::consts::TYPENAME_FIELD_NAME, }; @@ -9,20 +10,24 @@ pub fn traverse_and_callback_mut<'a, Callback>( current_data: &mut Value<'a>, remaining_path: &[FlattenNodePathSegment], schema_metadata: &SchemaMetadata, + current_error_path: Option, callback: &mut Callback, ) where - Callback: FnMut(&mut Value), + Callback: FnMut(&mut Value, Option), { if remaining_path.is_empty() { if let Value::Array(arr) = current_data { // If the path is empty, we call the callback on each item in the array // We iterate because we want the entity objects directly - for item in arr.iter_mut() { - callback(item); + for (index, item) in arr.iter_mut().enumerate() { + let current_error_path_for_index = current_error_path + .as_ref() + .map(|current_error_path| current_error_path.concat_index(index)); + callback(item, current_error_path_for_index); } } else { // If the path is empty and current_data is not an array, just call the callback - callback(current_data); + callback(current_data, current_error_path); } return; } @@ -32,8 +37,17 @@ pub fn traverse_and_callback_mut<'a, Callback>( // If the key is List, we expect current_data to be an array if let Value::Array(arr) = current_data { let rest_of_path = &remaining_path[1..]; - for item in arr.iter_mut() { - traverse_and_callback_mut(item, rest_of_path, schema_metadata, callback); + for (index, item) in arr.iter_mut().enumerate() { + let current_error_path_for_index = current_error_path + .as_ref() + .map(|current_error_path| current_error_path.concat_index(index)); + traverse_and_callback_mut( + item, + rest_of_path, + schema_metadata, + current_error_path_for_index, + callback, + ); } } } @@ -43,7 +57,17 @@ pub fn traverse_and_callback_mut<'a, Callback>( if let Ok(idx) = map.binary_search_by_key(&field_name.as_str(), |(k, _)| k) { let (_, next_data) = map.get_mut(idx).unwrap(); let rest_of_path = &remaining_path[1..]; - traverse_and_callback_mut(next_data, rest_of_path, schema_metadata, callback); + let current_error_path_for_field = + current_error_path.map(|current_error_path| { + current_error_path.concat_str(field_name.clone()) + }); + traverse_and_callback_mut( + next_data, + rest_of_path, + schema_metadata, + current_error_path_for_field, + callback, + ); } } } @@ -64,13 +88,23 @@ pub fn traverse_and_callback_mut<'a, Callback>( current_data, rest_of_path, schema_metadata, + current_error_path, callback, ); } } else if let Value::Array(arr) = current_data { // If the current data is an array, we need to check each item - for item in arr.iter_mut() { - traverse_and_callback_mut(item, remaining_path, schema_metadata, callback); + for (index, item) in arr.iter_mut().enumerate() { + let current_error_path_for_index = current_error_path + .as_ref() + .map(|current_error_path| current_error_path.concat_index(index)); + traverse_and_callback_mut( + item, + remaining_path, + schema_metadata, + current_error_path_for_index, + callback, + ); } } } @@ -139,3 +173,136 @@ where Ok(()) } + +#[cfg(test)] +mod tests { + use hive_router_query_planner::planner::plan_nodes::FlattenNodePathSegment; + + use crate::{ + introspection::schema::SchemaMetadata, + response::{ + graphql_error::{GraphQLErrorPath, GraphQLErrorPathSegment}, + value::Value, + }, + }; + + #[test] + /** + * Collect error paths for each item in a list at one level + * E.g. for data { items: [ {...}, {...} ] } and path ["items", List] + * we should collect paths ["items", 0] and ["items", 1] + */ + fn test_collect_error_paths_one_level() { + let mut data = Value::Object(vec![( + "items", + Value::Array(vec![ + Value::Object(vec![("id", Value::String("1".into()))]), + Value::Object(vec![("id", Value::String("2".into()))]), + ]), + )]); + let path = vec![ + FlattenNodePathSegment::Field("items".into()), + FlattenNodePathSegment::List, + ]; + let mut collected = vec![]; + super::traverse_and_callback_mut( + &mut data, + &path, + &SchemaMetadata::default(), + Some(GraphQLErrorPath::default()), + &mut |_item, error_path| { + collected.push(error_path.unwrap()); + }, + ); + assert_eq!(collected.len(), 2); + assert_eq!( + collected[0].segments, + vec![ + GraphQLErrorPathSegment::String("items".into()), + GraphQLErrorPathSegment::Index(0) + ] + ); + assert_eq!( + collected[1].segments, + vec![ + GraphQLErrorPathSegment::String("items".into()), + GraphQLErrorPathSegment::Index(1) + ] + ); + } + + #[test] + /** + * Collect error paths for each item in a list at two levels + * E.g. for data { users: [ { posts: [ {...}, {...} ] }, { posts: [ {...} ] } ] } and path ["users", List, "posts", List] + * we should collect paths ["users", 0, "posts", 0], ["users", 0, "posts", 1], and ["users", 1, "posts", 0] + */ + fn test_collect_error_paths_two_levels() { + let mut data = Value::Object(vec![( + "users", + Value::Array(vec![ + Value::Object(vec![ + ("id", Value::String("1".into())), + ( + "posts", + Value::Array(vec![ + Value::Object(vec![("id", Value::String("a".into()))]), + Value::Object(vec![("id", Value::String("b".into()))]), + ]), + ), + ]), + Value::Object(vec![ + ("id", Value::String("2".into())), + ( + "posts", + Value::Array(vec![Value::Object(vec![("id", Value::String("c".into()))])]), + ), + ]), + ]), + )]); + let path = vec![ + FlattenNodePathSegment::Field("users".into()), + FlattenNodePathSegment::List, + FlattenNodePathSegment::Field("posts".into()), + FlattenNodePathSegment::List, + ]; + let mut collected = vec![]; + super::traverse_and_callback_mut( + &mut data, + &path, + &SchemaMetadata::default(), + Some(GraphQLErrorPath::default()), + &mut |_item, error_path| { + collected.push(error_path.unwrap()); + }, + ); + assert_eq!(collected.len(), 3); + assert_eq!( + collected[0].segments, + vec![ + GraphQLErrorPathSegment::String("users".into()), + GraphQLErrorPathSegment::Index(0), + GraphQLErrorPathSegment::String("posts".into()), + GraphQLErrorPathSegment::Index(0), + ] + ); + assert_eq!( + collected[1].segments, + vec![ + GraphQLErrorPathSegment::String("users".into()), + GraphQLErrorPathSegment::Index(0), + GraphQLErrorPathSegment::String("posts".into()), + GraphQLErrorPathSegment::Index(1), + ] + ); + assert_eq!( + collected[2].segments, + vec![ + GraphQLErrorPathSegment::String("users".into()), + GraphQLErrorPathSegment::Index(1), + GraphQLErrorPathSegment::String("posts".into()), + GraphQLErrorPathSegment::Index(0), + ] + ); + } +}