diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt index 650cb679..239311bb 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt @@ -103,6 +103,8 @@ internal class SchemaClassScanner( } while (scanQueue()) } + handleDirectives() + return validateAndCreateResult(rootTypeHolder) } @@ -131,6 +133,30 @@ internal class SchemaClassScanner( scanResolverInfoForPotentialMatches(rootType.type, rootType.resolverInfo) } + + private fun handleDirectives() { + for (directive in directiveDefinitions) { + for (input in directive.inputValueDefinitions) { + handleDirectiveInput(input.type) + } + } + } + + private fun handleDirectiveInput(inputType: Type<*>) { + val inputTypeName = (inputType.unwrap() as TypeName).name + val typeDefinition = ScalarInfo.GRAPHQL_SPECIFICATION_SCALARS_DEFINITIONS[inputTypeName] + ?: definitionsByName[inputTypeName] + ?: error("No ${TypeDefinition::class.java.simpleName} for type name $inputTypeName") + when (typeDefinition) { + is ScalarTypeDefinition -> handleFoundScalarType(typeDefinition) + is InputObjectTypeDefinition -> { + for (input in typeDefinition.inputValueDefinitions) { + handleDirectiveInput(input.type) + } + } + } + } + private fun validateAndCreateResult(rootTypeHolder: RootTypesHolder): ScannedSchemaObjects { initialDictionary .filter { !it.value.accessed } @@ -280,16 +306,9 @@ internal class SchemaClassScanner( } } - private fun handleFoundType(match: TypeClassMatcher.Match) { - when (match) { - is TypeClassMatcher.ScalarMatch -> { - handleFoundScalarType(match.type) - } - - is TypeClassMatcher.ValidMatch -> { - handleFoundType(match.type, match.javaType, match.reference) - } - } + private fun handleFoundType(match: TypeClassMatcher.Match) = when (match) { + is TypeClassMatcher.ScalarMatch -> handleFoundScalarType(match.type) + is TypeClassMatcher.ValidMatch -> handleFoundType(match.type, match.javaType, match.reference) } private fun handleFoundScalarType(type: ScalarTypeDefinition) { @@ -392,12 +411,10 @@ internal class SchemaClassScanner( val filteredMethods = methods.filter { it.name == name || it.name == "get${name.replaceFirstChar(Char::titlecase)}" }.sortedBy { it.name.length } - return filteredMethods.find { - !it.isSynthetic - }?.genericReturnType ?: filteredMethods.firstOrNull( - )?.genericReturnType ?: clazz.fields.find { - it.name == name - }?.genericType + + return filteredMethods.find { !it.isSynthetic }?.genericReturnType + ?: filteredMethods.firstOrNull()?.genericReturnType + ?: clazz.fields.find { it.name == name }?.genericType } private data class QueueItem(val type: ObjectTypeDefinition, val clazz: JavaType) diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index 7772f7ab..f48ced77 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -1,5 +1,6 @@ package graphql.kickstart.tools +import graphql.Scalars import graphql.introspection.Introspection import graphql.introspection.Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION import graphql.kickstart.tools.directive.DirectiveWiringHelper @@ -335,7 +336,7 @@ class SchemaParser internal constructor( it.arguments.forEach { arg -> argument(GraphQLAppliedDirectiveArgument.newArgument() .name(arg.name) - .type(directiveWiringHelper.buildDirectiveInputType(arg.value)) + .type(buildDirectiveInputType(arg.value)) .valueLiteral(arg.value) .build() ) @@ -350,7 +351,76 @@ class SchemaParser internal constructor( directives: List, directiveLocation: Introspection.DirectiveLocation ): Array { - return directiveWiringHelper.buildDirectives(directives, directiveLocation).toTypedArray() + val names = mutableSetOf() + val output = mutableListOf() + + for (directive in directives) { + val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false + if (repeatable || !names.contains(directive.name)) { + names.add(directive.name) + output.add( + GraphQLDirective.newDirective() + .name(directive.name) + .description(getDocumentation(directive, options)) + .comparatorRegistry(runtimeWiring.comparatorRegistry) + .validLocation(directiveLocation) + .repeatable(repeatable) + .apply { + directive.arguments.forEach { arg -> + argument(GraphQLArgument.newArgument() + .name(arg.name) + .type(buildDirectiveInputType(arg.value)) + // TODO remove this once directives are fully replaced with applied directives + .valueLiteral(arg.value) + .build()) + } + } + .build() + ) + } + } + + return output.toTypedArray() + } + + private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? { + return when (value) { + is NullValue -> Scalars.GraphQLString + is FloatValue -> Scalars.GraphQLFloat + is StringValue -> Scalars.GraphQLString + is IntValue -> Scalars.GraphQLInt + is BooleanValue -> Scalars.GraphQLBoolean + is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value))) + // TODO to implement this we'll need to "observe" directive's input types + match them here based on their fields(?) + else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.") + } + } + + private fun getArrayValueWrappedType(value: ArrayValue): Value<*> { + // empty array [] is equivalent to [null] + if (value.values.isEmpty()) { + return NullValue.newNullValue().build() + } + + // get rid of null values + val nonNullValueList = value.values.filter { v -> v !is NullValue } + + // [null, null, ...] unwrapped is null + if (nonNullValueList.isEmpty()) { + return NullValue.newNullValue().build() + } + + // make sure the array isn't polymorphic + val distinctTypes = nonNullValueList + .map { it::class.java } + .distinct() + + if (distinctTypes.size > 1) { + throw SchemaError("Arrays containing multiple types of values are not supported yet.") + } + + // peek at first value, value exists and is assured to be non-null + return nonNullValueList[0] } private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List) = diff --git a/src/main/kotlin/graphql/kickstart/tools/TypeClassMatcher.kt b/src/main/kotlin/graphql/kickstart/tools/TypeClassMatcher.kt index 4ff5e451..324003f2 100644 --- a/src/main/kotlin/graphql/kickstart/tools/TypeClassMatcher.kt +++ b/src/main/kotlin/graphql/kickstart/tools/TypeClassMatcher.kt @@ -93,7 +93,7 @@ internal class TypeClassMatcher(private val definitionsByName: Map wireDirectives(wrapper: WiringWrapper): T { - val directivesContainer = wrapper.graphQlType.definition as DirectivesContainer<*> - val directives = buildDirectives(directivesContainer.directives, wrapper.directiveLocation) - val directivesByName = directives.associateBy { it.name } var output = wrapper.graphQlType // first the specific named directives wrapper.graphQlType.appliedDirectives.forEach { appliedDirective -> - val env = buildEnvironment(wrapper, directives, directivesByName[appliedDirective.name], appliedDirective) + val env = buildEnvironment(wrapper, appliedDirective) val wiring = runtimeWiring.registeredDirectiveWiring[appliedDirective.name] wiring?.let { output = wrapper.invoker(it, env) } } // now call any statically added to the runtime runtimeWiring.directiveWiring.forEach { staticWiring -> - val env = buildEnvironment(wrapper, directives, null, null) + val env = buildEnvironment(wrapper) output = wrapper.invoker(staticWiring, env) } // wiring factory is last (if present) - val env = buildEnvironment(wrapper, directives, null, null) + val env = buildEnvironment(wrapper) if (runtimeWiring.wiringFactory.providesSchemaDirectiveWiring(env)) { val factoryWiring = runtimeWiring.wiringFactory.getSchemaDirectiveWiring(env) output = wrapper.invoker(factoryWiring, env) @@ -100,46 +96,15 @@ class DirectiveWiringHelper( return output } - fun buildDirectives(directives: List, directiveLocation: Introspection.DirectiveLocation): List { - val names = mutableSetOf() - val output = mutableListOf() - - for (directive in directives) { - val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false - if (repeatable || !names.contains(directive.name)) { - names.add(directive.name) - output.add( - GraphQLDirective.newDirective() - .name(directive.name) - .description(getDocumentation(directive, options)) - .comparatorRegistry(runtimeWiring.comparatorRegistry) - .validLocation(directiveLocation) - .repeatable(repeatable) - .apply { - directive.arguments.forEach { arg -> - argument(GraphQLArgument.newArgument() - .name(arg.name) - .type(buildDirectiveInputType(arg.value)) - // TODO remove this once directives are fully replaced with applied directives - .valueLiteral(arg.value) - .build()) - } - } - .build() - ) - } - } - - return output - } - - private fun buildEnvironment(wrapper: WiringWrapper, directives: List, directive: GraphQLDirective?, appliedDirective: GraphQLAppliedDirective?): SchemaDirectiveWiringEnvironmentImpl { + private fun buildEnvironment(wrapper: WiringWrapper, appliedDirective: GraphQLAppliedDirective? = null): SchemaDirectiveWiringEnvironmentImpl { + val type = wrapper.graphQlType + val directive = appliedDirective?.let { d -> type.directives.find { it.name == d.name } } val nodeParentTree = buildAstTree(*listOfNotNull( wrapper.fieldsContainer?.definition, wrapper.inputFieldsContainer?.definition, wrapper.enumType?.definition, wrapper.fieldDefinition?.definition, - wrapper.graphQlType.definition + type.definition ).filterIsInstance>() .toTypedArray()) val elementParentTree = buildRuntimeTree(*listOfNotNull( @@ -147,55 +112,16 @@ class DirectiveWiringHelper( wrapper.inputFieldsContainer, wrapper.enumType, wrapper.fieldDefinition, - wrapper.graphQlType + type ).toTypedArray()) - val params = when (wrapper.graphQlType) { - is GraphQLFieldDefinition -> schemaDirectiveParameters.newParams(wrapper.graphQlType, wrapper.fieldsContainer, nodeParentTree, elementParentTree) + val params = when (type) { + is GraphQLFieldDefinition -> schemaDirectiveParameters.newParams(type, wrapper.fieldsContainer, nodeParentTree, elementParentTree) is GraphQLArgument -> schemaDirectiveParameters.newParams(wrapper.fieldDefinition, wrapper.fieldsContainer, nodeParentTree, elementParentTree) // object or interface - is GraphQLFieldsContainer -> schemaDirectiveParameters.newParams(wrapper.graphQlType, nodeParentTree, elementParentTree) + is GraphQLFieldsContainer -> schemaDirectiveParameters.newParams(type, nodeParentTree, elementParentTree) else -> schemaDirectiveParameters.newParams(nodeParentTree, elementParentTree) } - return SchemaDirectiveWiringEnvironmentImpl(wrapper.graphQlType, directives, wrapper.graphQlType.appliedDirectives, directive, appliedDirective, params) - } - - fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? { - return when (value) { - is NullValue -> Scalars.GraphQLString - is FloatValue -> Scalars.GraphQLFloat - is StringValue -> Scalars.GraphQLString - is IntValue -> Scalars.GraphQLInt - is BooleanValue -> Scalars.GraphQLBoolean - is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value))) - else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.") - } - } - - private fun getArrayValueWrappedType(value: ArrayValue): Value<*> { - // empty array [] is equivalent to [null] - if (value.values.isEmpty()) { - return NullValue.newNullValue().build() - } - - // get rid of null values - val nonNullValueList = value.values.filter { v -> v !is NullValue } - - // [null, null, ...] unwrapped is null - if (nonNullValueList.isEmpty()) { - return NullValue.newNullValue().build() - } - - // make sure the array isn't polymorphic - val distinctTypes = nonNullValueList - .map { it::class.java } - .distinct() - - if (distinctTypes.size > 1) { - throw SchemaError("Arrays containing multiple types of values are not supported yet.") - } - - // peek at first value, value exists and is assured to be non-null - return nonNullValueList[0] + return SchemaDirectiveWiringEnvironmentImpl(type, type.directives, type.appliedDirectives, directive, appliedDirective, params) } private fun buildAstTree(vararg nodes: NamedNode<*>): NodeParentTree> { diff --git a/src/test/kotlin/graphql/kickstart/tools/InaccessibleFieldResolverTest.kt b/src/test/kotlin/graphql/kickstart/tools/InaccessibleFieldResolverTest.kt index 8c78ef04..dc0d1814 100644 --- a/src/test/kotlin/graphql/kickstart/tools/InaccessibleFieldResolverTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/InaccessibleFieldResolverTest.kt @@ -4,7 +4,6 @@ import graphql.ExceptionWhileDataFetching import graphql.GraphQL import graphql.execution.AsyncExecutionStrategy import graphql.schema.GraphQLSchema -import org.junit.Ignore import org.junit.Test import java.util.* @@ -16,7 +15,6 @@ import java.util.* class InaccessibleFieldResolverTest { @Test - @Ignore // TODO enable test after upgrading to 17 fun `private field from closed module is not accessible`() { val schema: GraphQLSchema = SchemaParser.newParser() .schemaString( diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt index 9e96d925..24ec19b2 100644 --- a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt @@ -5,7 +5,6 @@ import graphql.execution.CoercedVariables import graphql.language.Value import graphql.schema.* import kotlinx.coroutines.ExperimentalCoroutinesApi -import org.junit.Ignore import org.junit.Test import java.util.* import java.util.concurrent.CompletableFuture @@ -425,13 +424,18 @@ class SchemaClassScannerTest { } @Test - @Ignore("TODO remove this once directives are fully replaced with applied directives OR issue #664 is resolved") fun `scanner should handle unused types when option is true`() { val schema = SchemaParser.newParser() .schemaString( """ - # Let's say this is the Products service from Apollo Federation Introduction + # these directives are defined in the Apollo Federation Specification: + # https://www.apollographql.com/docs/apollo-server/federation/federation-spec/ + scalar FieldSet + directive @key(fields: FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE + directive @extends on OBJECT | INTERFACE + directive @external on FIELD_DEFINITION | OBJECT + # Let's say this is the Products service from Apollo Federation Introduction type Query { allProducts: [Product] } @@ -440,8 +444,6 @@ class SchemaClassScannerTest { name: String } - # these directives are defined in the Apollo Federation Specification: - # https://www.apollographql.com/docs/apollo-server/federation/federation-spec/ type User @key(fields: "id") @extends { id: ID! @external recentPurchasedProducts: [Product] @@ -457,6 +459,7 @@ class SchemaClassScannerTest { }) .options(SchemaParserOptions.newOptions().includeUnusedTypes(true).build()) .dictionary(User::class) + .scalars(fieldSetScalar) .build() .makeExecutableSchema() @@ -465,6 +468,19 @@ class SchemaClassScannerTest { assert(objectTypes.any { it.name == "Address" }) } + data class FieldSet(val value: String) + + private val fieldSetScalar: GraphQLScalarType = GraphQLScalarType.newScalar() + .name("FieldSet") + .coercing(object : Coercing { + override fun serialize(input: Any, context: GraphQLContext, locale: Locale) = input.toString() + override fun parseValue(input: Any, context: GraphQLContext, locale: Locale) = + FieldSet(input.toString()) + override fun parseLiteral(input: Value<*>, variables: CoercedVariables, context: GraphQLContext, locale: Locale) = + FieldSet(input.toString()) + }) + .build() + class Product { var name: String? = null }