Skip to content

Commit a6393e0

Browse files
committed
Fix for #409
1 parent 313afc5 commit a6393e0

File tree

2 files changed

+67
-12
lines changed

2 files changed

+67
-12
lines changed

src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class SchemaParser internal constructor(
8181
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
8282
inputObjectDefinitions.forEach {
8383
if (inputObjects.none { io -> io.name == it.name }) {
84-
inputObjects.add(createInputObject(it, inputObjects))
84+
inputObjects.add(createInputObject(it, inputObjects, mutableSetOf()))
8585
}
8686
}
8787
val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
@@ -173,7 +173,8 @@ class SchemaParser internal constructor(
173173
return output.toTypedArray()
174174
}
175175

176-
private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLInputObjectType {
176+
private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>,
177+
referencingInputObjects: MutableSet<String>): GraphQLInputObjectType {
177178
val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name }
178179

179180
val builder = GraphQLInputObjectType.newInputObject()
@@ -184,14 +185,16 @@ class SchemaParser internal constructor(
184185

185186
builder.withDirectives(*buildDirectives(definition.directives, setOf(), Introspection.DirectiveLocation.INPUT_OBJECT))
186187

188+
referencingInputObjects.add(definition.name)
189+
187190
(extensionDefinitions + definition).forEach {
188191
it.inputValueDefinitions.forEach { inputDefinition ->
189192
val fieldBuilder = GraphQLInputObjectField.newInputObjectField()
190193
.name(inputDefinition.name)
191194
.definition(inputDefinition)
192195
.description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition))
193196
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
194-
.type(determineInputType(inputDefinition.type, inputObjects))
197+
.type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects))
195198
.withDirectives(*buildDirectives(inputDefinition.directives, setOf(), Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION))
196199
builder.field(fieldBuilder.build())
197200
}
@@ -297,7 +300,7 @@ class SchemaParser internal constructor(
297300
.definition(argumentDefinition)
298301
.description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition))
299302
.defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
300-
.type(determineInputType(argumentDefinition.type, inputObjects))
303+
.type(determineInputType(argumentDefinition.type, inputObjects, setOf()))
301304
.withDirectives(*buildDirectives(argumentDefinition.directives, setOf(), Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
302305
field.argument(argumentBuilder.build())
303306
}
@@ -328,7 +331,7 @@ class SchemaParser internal constructor(
328331
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
329332
is InputObjectTypeDefinition -> {
330333
log.info("Create input object")
331-
createInputObject(typeDefinition, inputObjects)
334+
createInputObject(typeDefinition, inputObjects, mutableSetOf())
332335
}
333336
is TypeName -> {
334337
val scalarType = customScalars[typeDefinition.name]
@@ -346,16 +349,19 @@ class SchemaParser internal constructor(
346349
else -> throw SchemaError("Unknown type: $typeDefinition")
347350
}
348351

349-
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
350-
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
352+
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: Set<String>) =
353+
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects) as GraphQLInputType
351354

352-
private fun <T : Any> determineInputType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>, inputObjects: List<GraphQLInputObjectType>): GraphQLType =
355+
private fun <T : Any> determineInputType(expectedType: KClass<T>,
356+
typeDefinition: Type<*>, allowedTypeReferences: Set<String>,
357+
inputObjects: List<GraphQLInputObjectType>,
358+
referencingInputObjects: Set<String>): GraphQLType =
353359
when (typeDefinition) {
354360
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
355361
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
356362
is InputObjectTypeDefinition -> {
357363
log.info("Create input object")
358-
createInputObject(typeDefinition, inputObjects)
364+
createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet<String>)
359365
}
360366
is TypeName -> {
361367
val scalarType = customScalars[typeDefinition.name]
@@ -373,9 +379,14 @@ class SchemaParser internal constructor(
373379
} else {
374380
val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
375381
if (filteredDefinitions.isNotEmpty()) {
376-
val inputObject = createInputObject(filteredDefinitions[0], inputObjects)
377-
(inputObjects as MutableList).add(inputObject)
378-
inputObject
382+
val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name }
383+
if (referencingInputObject != null) {
384+
GraphQLTypeReference(referencingInputObject)
385+
} else {
386+
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet<String>)
387+
(inputObjects as MutableList).add(inputObject)
388+
inputObject
389+
}
379390
} else {
380391
// todo: handle enum type
381392
GraphQLTypeReference(typeDefinition.name)

src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,50 @@ class SchemaParserSpec extends Specification {
368368
noExceptionThrown()
369369
}
370370

371+
def "allow circular relations in input objects"() {
372+
when:
373+
SchemaParser.newParser().schemaString('''\
374+
input A {
375+
id: ID!
376+
b: B
377+
}
378+
input B {
379+
id: ID!
380+
a: A
381+
}
382+
input C {
383+
id: ID!
384+
c: C
385+
}
386+
type Query {}
387+
type Mutation {
388+
test(input: A!): Boolean
389+
testC(input: C!): Boolean
390+
}
391+
'''.stripIndent())
392+
.resolvers(new GraphQLMutationResolver() {
393+
static class A {
394+
String id;
395+
B b;
396+
}
397+
static class B {
398+
String id;
399+
A a;
400+
}
401+
static class C {
402+
String id;
403+
C c;
404+
}
405+
boolean test(A a) { return true }
406+
boolean testC(C c) { return true }
407+
}, new GraphQLQueryResolver() {})
408+
.build()
409+
.makeExecutableSchema()
410+
411+
then:
412+
noExceptionThrown()
413+
}
414+
371415
enum EnumType {
372416
TEST
373417
}

0 commit comments

Comments
 (0)