Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 57 additions & 109 deletions node-graph/node-macro/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,20 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
let input_ident = &input.pat_ident;

let field_idents: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { pat_ident, .. } | ParsedField::Node { pat_ident, .. } => pat_ident,
})
.collect();
let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect();
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();

let input_names: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { name, .. } | ParsedField::Node { name, .. } => name,
})
.map(|f| &f.name)
.zip(field_names.iter())
.map(|zipped| match zipped {
(Some(name), _) => name.value(),
(_, name) => name.to_string().to_case(Case::Title),
})
.collect();

let input_descriptions: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { description, .. } | ParsedField::Node { description, .. } => description,
})
.collect();
let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect();

let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| {
quote! { pub(super) #name: #r#gen }
Expand All @@ -84,9 +72,9 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre

let field_types: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { ty, .. } => ty.clone(),
ParsedField::Node { output_type, input_type, .. } => match parsed.is_async {
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
ParsedFieldType::Node(NodeParsedField { output_type, input_type, .. }) => match parsed.is_async {
true => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = impl core::future::Future<Output=#output_type>>),
false => parse_quote!(&'n impl #graphene_core::Node<'n, #input_type, Output = #output_type>),
},
Expand All @@ -95,24 +83,18 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre

let widget_override: Vec<_> = fields
.iter()
.map(|field| {
let parsed_widget_override = match field {
ParsedField::Regular { widget_override, .. } => widget_override,
ParsedField::Node { widget_override, .. } => widget_override,
};
match parsed_widget_override {
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
ParsedWidgetOverride::Hidden => quote!(RegistryWidgetOverride::Hidden),
ParsedWidgetOverride::String(lit_str) => quote!(RegistryWidgetOverride::String(#lit_str)),
ParsedWidgetOverride::Custom(lit_str) => quote!(RegistryWidgetOverride::Custom(#lit_str)),
}
.map(|field| match &field.widget_override {
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
ParsedWidgetOverride::Hidden => quote!(RegistryWidgetOverride::Hidden),
ParsedWidgetOverride::String(lit_str) => quote!(RegistryWidgetOverride::String(#lit_str)),
ParsedWidgetOverride::Custom(lit_str) => quote!(RegistryWidgetOverride::Custom(#lit_str)),
})
.collect();

let value_sources: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { value_source, .. } => match value_source {
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
ParsedValueSource::Default(data) => quote!(RegistryValueSource::Default(stringify!(#data))),
ParsedValueSource::Scope(data) => quote!(RegistryValueSource::Scope(#data)),
_ => quote!(RegistryValueSource::None),
Expand All @@ -123,8 +105,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre

let default_types: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { implementations, .. } => match implementations.first() {
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
Some(ty) => quote!(Some(concrete!(#ty))),
_ => quote!(None),
},
Expand All @@ -134,8 +116,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre

let number_min_values: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { number_soft_min, number_hard_min, .. } => match (number_soft_min, number_hard_min) {
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
(Some(soft_min), _) => quote!(Some(#soft_min)),
(None, Some(hard_min)) => quote!(Some(#hard_min)),
(None, None) => quote!(None),
Expand All @@ -145,8 +127,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
.collect();
let number_max_values: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { number_soft_max, number_hard_max, .. } => match (number_soft_max, number_hard_max) {
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
(Some(soft_max), _) => quote!(Some(#soft_max)),
(None, Some(hard_max)) => quote!(Some(#hard_max)),
(None, None) => quote!(None),
Expand All @@ -156,77 +138,45 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
.collect();
let number_mode_range_values: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular {
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField {
number_mode_range: Some(number_mode_range),
..
} => quote!(Some(#number_mode_range)),
}) => quote!(Some(#number_mode_range)),
_ => quote!(None),
})
.collect();
let number_display_decimal_places: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular {
number_display_decimal_places: Some(decimal_places),
..
}
| ParsedField::Node {
number_display_decimal_places: Some(decimal_places),
..
} => {
quote!(Some(#decimal_places))
}
_ => quote!(None),
})
.collect();
let number_step: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { number_step: Some(step), .. } | ParsedField::Node { number_step: Some(step), .. } => {
quote!(Some(#step))
}
_ => quote!(None),
})
.map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
.collect();
let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();

let unit_suffix: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { unit: Some(unit), .. } | ParsedField::Node { unit: Some(unit), .. } => {
quote!(Some(#unit))
}
_ => quote!(None),
})
.collect();
let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();

let exposed: Vec<_> = fields
.iter()
.map(|field| match field {
ParsedField::Regular { exposed, .. } => quote!(#exposed),
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
_ => quote!(true),
})
.collect();

let eval_args = fields.iter().map(|field| match field {
ParsedField::Regular { pat_ident, .. } => {
let name = &pat_ident.ident;
quote! { let #name = self.#name.eval(__input.clone()).await; }
}
ParsedField::Node { pat_ident, .. } => {
let name = &pat_ident.ident;
quote! { let #name = &self.#name; }
let eval_args = fields.iter().map(|field| {
let name = &field.pat_ident.ident;
match &field.ty {
ParsedFieldType::Regular { .. } => {
quote! { let #name = self.#name.eval(__input.clone()).await; }
}
ParsedFieldType::Node { .. } => {
quote! { let #name = &self.#name; }
}
}
});

let min_max_args = fields.iter().map(|field| match field {
ParsedField::Regular {
pat_ident,
number_hard_min,
number_hard_max,
..
} => {
let name = &pat_ident.ident;
let min_max_args = fields.iter().map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
let name = &field.pat_ident.ident;
let mut tokens = quote!();
if let Some(min) = number_hard_min {
tokens.extend(quote_spanned! {min.span()=>
Expand All @@ -241,15 +191,13 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
}
tokens
}
ParsedField::Node { .. } => {
quote!()
}
ParsedFieldType::Node { .. } => quote!(),
});

let all_implementation_types = fields.iter().flat_map(|field| match field {
ParsedField::Regular { implementations, .. } => implementations.into_iter().cloned().collect::<Vec<_>>(),
ParsedField::Node { implementations, .. } => implementations
.into_iter()
let all_implementation_types = fields.iter().flat_map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => implementations.iter().cloned().collect::<Vec<_>>(),
ParsedFieldType::Node(NodeParsedField { implementations, .. }) => implementations
.iter()
.flat_map(|implementation| [implementation.input.clone(), implementation.output.clone()])
.collect(),
});
Expand All @@ -260,11 +208,11 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
let mut clampable_clauses = Vec::new();

for (field, name) in fields.iter().zip(struct_generics.iter()) {
clauses.push(match (field, *is_async) {
clauses.push(match (&field.ty, *is_async) {
(
ParsedField::Regular {
ParsedFieldType::Regular(RegularParsedField {
ty, number_hard_min, number_hard_max, ..
},
}),
_,
) => {
let all_lifetime_ty = substitute_lifetimes(ty.clone(), "all");
Expand All @@ -284,7 +232,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
#name: #graphene_core::Node<'n, #input_type, Output = #fut_ident> + #graphene_core::WasmNotSync
)
}
(ParsedField::Node { input_type, output_type, .. }, true) => {
(ParsedFieldType::Node(NodeParsedField { input_type, output_type, .. }), true) => {
let id = future_idents.len();
let fut_ident = format_ident!("F{}", id);
future_idents.push(fut_ident.clone());
Expand All @@ -294,7 +242,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
#name: #graphene_core::Node<'n, #input_type, Output = #fut_ident > + #graphene_core::WasmNotSync
)
}
(ParsedField::Node { .. }, false) => unreachable!(),
(ParsedFieldType::Node { .. }, false) => unreachable!(),
});
}
let where_clause = where_clause.clone().unwrap_or(WhereClause {
Expand Down Expand Up @@ -454,9 +402,9 @@ fn generate_node_input_references(
let (mut modified, mut generic_collector) = FilterUsedGenerics::new(fn_generics);

for (input_index, (parsed_input, input_ident)) in parsed.fields.iter().zip(field_idents).enumerate() {
let mut ty = match parsed_input {
ParsedField::Regular { ty, .. } => ty,
ParsedField::Node { output_type, .. } => output_type,
let mut ty = match &parsed_input.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty,
ParsedFieldType::Node(NodeParsedField { output_type, .. }) => output_type,
}
.clone();

Expand Down Expand Up @@ -540,20 +488,20 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
.fields
.iter()
.map(|field| {
match field {
ParsedField::Regular { implementations, ty, .. } => {
match &field.ty {
ParsedFieldType::Regular(RegularParsedField { implementations, ty, .. }) => {
if !implementations.is_empty() {
implementations.iter().map(|ty| (&unit, ty)).collect()
} else {
vec![(&unit, ty)]
}
}
ParsedField::Node {
ParsedFieldType::Node(NodeParsedField {
implementations,
input_type,
output_type,
..
} => {
}) => {
if !implementations.is_empty() {
implementations.iter().map(|impl_| (&impl_.input, &impl_.output)).collect()
} else {
Expand All @@ -578,7 +526,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
let field_name = field_names[j];
let (input_type, output_type) = &types[i.min(types.len() - 1)];

let node = matches!(parsed.fields[j], ParsedField::Node { .. });
let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. });

let downcast_node = quote!(
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());
Expand Down
Loading