@@ -81,7 +81,7 @@ class SchemaParser internal constructor(
81
81
val inputObjects: MutableList <GraphQLInputObjectType > = mutableListOf ()
82
82
inputObjectDefinitions.forEach {
83
83
if (inputObjects.none { io -> io.name == it.name }) {
84
- inputObjects.add(createInputObject(it, inputObjects))
84
+ inputObjects.add(createInputObject(it, inputObjects, mutableSetOf () ))
85
85
}
86
86
}
87
87
val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
@@ -173,7 +173,8 @@ class SchemaParser internal constructor(
173
173
return output.toTypedArray()
174
174
}
175
175
176
- private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List <GraphQLInputObjectType >): GraphQLInputObjectType {
176
+ private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List <GraphQLInputObjectType >,
177
+ referencingInputObjects : MutableSet <String >): GraphQLInputObjectType {
177
178
val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name }
178
179
179
180
val builder = GraphQLInputObjectType .newInputObject()
@@ -184,14 +185,16 @@ class SchemaParser internal constructor(
184
185
185
186
builder.withDirectives(* buildDirectives(definition.directives, setOf (), Introspection .DirectiveLocation .INPUT_OBJECT ))
186
187
188
+ referencingInputObjects.add(definition.name)
189
+
187
190
(extensionDefinitions + definition).forEach {
188
191
it.inputValueDefinitions.forEach { inputDefinition ->
189
192
val fieldBuilder = GraphQLInputObjectField .newInputObjectField()
190
193
.name(inputDefinition.name)
191
194
.definition(inputDefinition)
192
195
.description(if (inputDefinition.description != null ) inputDefinition.description.content else getDocumentation(inputDefinition))
193
196
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
194
- .type(determineInputType(inputDefinition.type, inputObjects))
197
+ .type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects ))
195
198
.withDirectives(* buildDirectives(inputDefinition.directives, setOf (), Introspection .DirectiveLocation .INPUT_FIELD_DEFINITION ))
196
199
builder.field(fieldBuilder.build())
197
200
}
@@ -297,7 +300,7 @@ class SchemaParser internal constructor(
297
300
.definition(argumentDefinition)
298
301
.description(if (argumentDefinition.description != null ) argumentDefinition.description.content else getDocumentation(argumentDefinition))
299
302
.defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
300
- .type(determineInputType(argumentDefinition.type, inputObjects))
303
+ .type(determineInputType(argumentDefinition.type, inputObjects, setOf () ))
301
304
.withDirectives(* buildDirectives(argumentDefinition.directives, setOf (), Introspection .DirectiveLocation .ARGUMENT_DEFINITION ))
302
305
field.argument(argumentBuilder.build())
303
306
}
@@ -328,7 +331,7 @@ class SchemaParser internal constructor(
328
331
is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
329
332
is InputObjectTypeDefinition -> {
330
333
log.info(" Create input object" )
331
- createInputObject(typeDefinition, inputObjects)
334
+ createInputObject(typeDefinition, inputObjects, mutableSetOf () )
332
335
}
333
336
is TypeName -> {
334
337
val scalarType = customScalars[typeDefinition.name]
@@ -346,16 +349,19 @@ class SchemaParser internal constructor(
346
349
else -> throw SchemaError (" Unknown type: $typeDefinition " )
347
350
}
348
351
349
- private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >) =
350
- determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
352
+ private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >, referencingInputObjects : Set < String > ) =
353
+ determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects ) as GraphQLInputType
351
354
352
- private fun <T : Any > determineInputType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >, inputObjects : List <GraphQLInputObjectType >): GraphQLType =
355
+ private fun <T : Any > determineInputType (expectedType : KClass <T >,
356
+ typeDefinition : Type <* >, allowedTypeReferences : Set <String >,
357
+ inputObjects : List <GraphQLInputObjectType >,
358
+ referencingInputObjects : Set <String >): GraphQLType =
353
359
when (typeDefinition) {
354
360
is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
355
361
is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
356
362
is InputObjectTypeDefinition -> {
357
363
log.info(" Create input object" )
358
- createInputObject(typeDefinition, inputObjects)
364
+ createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet < String > )
359
365
}
360
366
is TypeName -> {
361
367
val scalarType = customScalars[typeDefinition.name]
@@ -373,9 +379,14 @@ class SchemaParser internal constructor(
373
379
} else {
374
380
val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
375
381
if (filteredDefinitions.isNotEmpty()) {
376
- val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects)
377
- (inputObjects as MutableList ).add(inputObject)
378
- inputObject
382
+ val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name }
383
+ if (referencingInputObject != null ) {
384
+ GraphQLTypeReference (referencingInputObject)
385
+ } else {
386
+ val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects, referencingInputObjects as MutableSet <String >)
387
+ (inputObjects as MutableList ).add(inputObject)
388
+ inputObject
389
+ }
379
390
} else {
380
391
// todo: handle enum type
381
392
GraphQLTypeReference (typeDefinition.name)
0 commit comments