Skip to content

Commit 4ae6165

Browse files
committed
Scan directives arguments while parsing schema
1 parent 7fb45d6 commit 4ae6165

File tree

4 files changed

+223
-77
lines changed

4 files changed

+223
-77
lines changed

src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,11 @@ internal class SchemaClassScanner(
149149
?: error("No ${TypeDefinition::class.java.simpleName} for type name $inputTypeName")
150150
when (typeDefinition) {
151151
is ScalarTypeDefinition -> handleFoundScalarType(typeDefinition)
152-
is InputObjectTypeDefinition -> {
153-
for (input in typeDefinition.inputValueDefinitions) {
154-
handleDirectiveInput(input.type)
155-
}
152+
is EnumTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
153+
"Enum type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
154+
}
155+
is InputObjectTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
156+
"Input object type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
156157
}
157158
}
158159
}
@@ -209,9 +210,9 @@ internal class SchemaClassScanner(
209210
log.warn("Schema type was defined but can never be accessed, and can be safely deleted: ${definition.name}")
210211
}
211212

212-
val fieldResolvers = fieldResolversByType.flatMap { it.value.map { it.value } }
213-
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<NormalResolverInfo>()
214-
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }
213+
val fieldResolvers = fieldResolversByType.flatMap { entry -> entry.value.map { it.value } }
214+
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<NormalResolverInfo>().toSet()
215+
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }.toSet()
215216

216217
(resolverInfos - observedNormalResolverInfos - observedMultiResolverInfos).forEach { resolverInfo ->
217218
log.warn("Resolver was provided but no methods on it were used in data fetchers, and can be safely deleted: ${resolverInfo.resolver}")
@@ -255,7 +256,7 @@ internal class SchemaClassScanner(
255256
}.flatten().distinct()
256257
}
257258

258-
private fun handleDictionaryTypes(types: List<ObjectTypeDefinition>, failureMessage: (ObjectTypeDefinition) -> String) {
259+
private fun handleDictionaryTypes(types: List<TypeDefinition<*>>, failureMessage: (TypeDefinition<*>) -> String) {
259260
types.forEach { type ->
260261
val dictionaryContainsType = dictionary.filter { it.key.name == type.name }.isNotEmpty()
261262
if (!unvalidatedTypes.contains(type) && !dictionaryContainsType) {

src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt

Lines changed: 71 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package graphql.kickstart.tools
22

3-
import graphql.Scalars
43
import graphql.introspection.Introspection
54
import graphql.introspection.Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION
65
import graphql.kickstart.tools.directive.DirectiveWiringHelper
@@ -9,6 +8,7 @@ import graphql.kickstart.tools.util.getExtendedFieldDefinitions
98
import graphql.kickstart.tools.util.unwrap
109
import graphql.language.*
1110
import graphql.schema.*
11+
import graphql.schema.idl.DirectiveInfo
1212
import graphql.schema.idl.RuntimeWiring
1313
import graphql.schema.idl.ScalarInfo
1414
import graphql.schema.visibility.NoIntrospectionGraphqlFieldVisibility
@@ -60,6 +60,8 @@ class SchemaParser internal constructor(
6060
private val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry()
6161
private val directiveWiringHelper = DirectiveWiringHelper(options, runtimeWiring, codeRegistryBuilder, directiveDefinitions)
6262

63+
private lateinit var schemaDirectives : Set<GraphQLDirective>
64+
6365
/**
6466
* Parses the given schema with respect to the given dictionary and returns GraphQL objects.
6567
*/
@@ -72,6 +74,7 @@ class SchemaParser internal constructor(
7274

7375
// Create GraphQL objects
7476
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
77+
schemaDirectives = createDirectives(inputObjects)
7578
inputObjectDefinitions.forEach {
7679
if (inputObjects.none { io -> io.name == it.name }) {
7780
inputObjects.add(createInputObject(it, inputObjects, mutableSetOf()))
@@ -82,8 +85,6 @@ class SchemaParser internal constructor(
8285
val unions = unionDefinitions.map { createUnionObject(it, objects) }
8386
val enums = enumDefinitions.map { createEnumObject(it) }
8487

85-
val directives = directiveDefinitions.map { createDirective(it, inputObjects) }.toSet()
86-
8788
// Assign type resolver to interfaces now that we know all of the object types
8889
interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it)) }
8990
unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it)) }
@@ -103,7 +104,7 @@ class SchemaParser internal constructor(
103104
val additionalObjects = objects.filter { o -> o != query && o != subscription && o != mutation }
104105

105106
val types = (additionalObjects.toSet() as Set<GraphQLType>) + inputObjects + enums + interfaces + unions
106-
return SchemaObjects(query, mutation, subscription, types, directives, codeRegistryBuilder, rootInfo.getDescription())
107+
return SchemaObjects(query, mutation, subscription, types, schemaDirectives, codeRegistryBuilder, rootInfo.getDescription())
107108
}
108109

109110
/**
@@ -300,44 +301,77 @@ class SchemaParser internal constructor(
300301
.name(definition.name)
301302
.definition(definition)
302303
.description(getDocumentation(definition, options))
303-
.type(determineInputType(definition.type, inputObjects, setOf()))
304+
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
304305
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
305306
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
306307
.withAppliedDirectives(*buildAppliedDirectives(definition.directives))
307308
.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
308309
.build()
309310
}
310311

311-
private fun createDirective(definition: DirectiveDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLDirective {
312-
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()
312+
private fun createDirectives(inputObjects: MutableList<GraphQLInputObjectType>): Set<GraphQLDirective> {
313+
schemaDirectives = directiveDefinitions.map { definition ->
314+
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()
315+
316+
GraphQLDirective.newDirective()
317+
.name(definition.name)
318+
.description(getDocumentation(definition, options))
319+
.definition(definition)
320+
.comparatorRegistry(runtimeWiring.comparatorRegistry)
321+
.validLocations(*locations)
322+
.repeatable(definition.isRepeatable)
323+
.apply {
324+
definition.inputValueDefinitions.forEach { argumentDefinition ->
325+
argument(createDirectiveArgument(argumentDefinition, inputObjects))
326+
}
327+
}
328+
.build()
329+
}.toSet()
330+
// because the arguments can have directives too, we attach them only after the directives themselves are created
331+
schemaDirectives = schemaDirectives.map { d ->
332+
val arguments = d.arguments.map { a -> a.transform {
333+
it.withAppliedDirectives(*buildAppliedDirectives(a.definition!!.directives))
334+
.withDirectives(*buildDirectives(a.definition!!.directives, Introspection.DirectiveLocation.OBJECT))
335+
} }
336+
d.transform { it.replaceArguments(arguments) }
337+
}.toSet()
338+
339+
return schemaDirectives
340+
}
313341

314-
return GraphQLDirective.newDirective()
342+
private fun createDirectiveArgument(definition: InputValueDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLArgument {
343+
return GraphQLArgument.newArgument()
315344
.name(definition.name)
316-
.description(getDocumentation(definition, options))
317345
.definition(definition)
318-
.comparatorRegistry(runtimeWiring.comparatorRegistry)
319-
.validLocations(*locations)
320-
.repeatable(definition.isRepeatable)
321-
.apply {
322-
definition.inputValueDefinitions.forEach { argumentDefinition ->
323-
argument(createArgument(argumentDefinition, inputObjects))
324-
}
325-
}
346+
.description(getDocumentation(definition, options))
347+
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
348+
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
349+
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
326350
.build()
327351
}
328352

329353
private fun buildAppliedDirectives(directives: List<Directive>): Array<GraphQLAppliedDirective> {
330-
return directives.map {
354+
return directives.map { directive ->
355+
val graphQLDirective = schemaDirectives.find { d -> d.name == directive.name }
356+
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
357+
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
358+
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }
359+
331360
GraphQLAppliedDirective.newDirective()
332-
.name(it.name)
333-
.description(getDocumentation(it, options))
361+
.name(directive.name)
362+
.description(getDocumentation(directive, options))
363+
.definition(directive)
334364
.comparatorRegistry(runtimeWiring.comparatorRegistry)
335365
.apply {
336-
it.arguments.forEach { arg ->
366+
directive.arguments.forEach { arg ->
367+
val graphQLArgument = graphQLArguments[arg.name]
368+
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name} .")
337369
argument(GraphQLAppliedDirectiveArgument.newArgument()
338370
.name(arg.name)
339-
.type(buildDirectiveInputType(arg.value))
371+
// TODO instead of guessing the type from its value, lookup the directive definition
372+
.type(graphQLArgument.type)
340373
.valueLiteral(arg.value)
374+
.description(graphQLArgument.description)
341375
.build()
342376
)
343377
}
@@ -358,6 +392,10 @@ class SchemaParser internal constructor(
358392
val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false
359393
if (repeatable || !names.contains(directive.name)) {
360394
names.add(directive.name)
395+
val graphQLDirective = this.schemaDirectives.find { d -> d.name == directive.name }
396+
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
397+
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
398+
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }
361399
output.add(
362400
GraphQLDirective.newDirective()
363401
.name(directive.name)
@@ -367,9 +405,11 @@ class SchemaParser internal constructor(
367405
.repeatable(repeatable)
368406
.apply {
369407
directive.arguments.forEach { arg ->
408+
val graphQLArgument = graphQLArguments[arg.name]
409+
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name}.")
370410
argument(GraphQLArgument.newArgument()
371411
.name(arg.name)
372-
.type(buildDirectiveInputType(arg.value))
412+
.type(graphQLArgument.type)
373413
// TODO remove this once directives are fully replaced with applied directives
374414
.valueLiteral(arg.value)
375415
.build())
@@ -383,46 +423,6 @@ class SchemaParser internal constructor(
383423
return output.toTypedArray()
384424
}
385425

386-
private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? {
387-
return when (value) {
388-
is NullValue -> Scalars.GraphQLString
389-
is FloatValue -> Scalars.GraphQLFloat
390-
is StringValue -> Scalars.GraphQLString
391-
is IntValue -> Scalars.GraphQLInt
392-
is BooleanValue -> Scalars.GraphQLBoolean
393-
is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value)))
394-
// TODO to implement this we'll need to "observe" directive's input types + match them here based on their fields(?)
395-
else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.")
396-
}
397-
}
398-
399-
private fun getArrayValueWrappedType(value: ArrayValue): Value<*> {
400-
// empty array [] is equivalent to [null]
401-
if (value.values.isEmpty()) {
402-
return NullValue.newNullValue().build()
403-
}
404-
405-
// get rid of null values
406-
val nonNullValueList = value.values.filter { v -> v !is NullValue }
407-
408-
// [null, null, ...] unwrapped is null
409-
if (nonNullValueList.isEmpty()) {
410-
return NullValue.newNullValue().build()
411-
}
412-
413-
// make sure the array isn't polymorphic
414-
val distinctTypes = nonNullValueList
415-
.map { it::class.java }
416-
.distinct()
417-
418-
if (distinctTypes.size > 1) {
419-
throw SchemaError("Arrays containing multiple types of values are not supported yet.")
420-
}
421-
422-
// peek at first value, value exists and is assured to be non-null
423-
return nonNullValueList[0]
424-
}
425-
426426
private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
427427
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType
428428

@@ -455,13 +455,15 @@ class SchemaParser internal constructor(
455455
else -> throw SchemaError("Unknown type: $typeDefinition")
456456
}
457457

458-
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: Set<String>) =
458+
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: MutableSet<String>) =
459459
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects)
460460

461-
private fun <T : Any> determineInputType(expectedType: KClass<T>,
462-
typeDefinition: Type<*>, allowedTypeReferences: Set<String>,
463-
inputObjects: List<GraphQLInputObjectType>,
464-
referencingInputObjects: Set<String>): GraphQLInputType =
461+
private fun <T : Any> determineInputType(
462+
expectedType: KClass<T>,
463+
typeDefinition: Type<*>,
464+
allowedTypeReferences: Set<String>,
465+
inputObjects: List<GraphQLInputObjectType>,
466+
referencingInputObjects: MutableSet<String>): GraphQLInputType =
465467
when (typeDefinition) {
466468
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
467469
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
@@ -489,7 +491,7 @@ class SchemaParser internal constructor(
489491
if (referencingInputObject != null) {
490492
GraphQLTypeReference(referencingInputObject)
491493
} else {
492-
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet<String>)
494+
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects)
493495
(inputObjects as MutableList).add(inputObject)
494496
inputObject
495497
}

0 commit comments

Comments
 (0)