diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaClassScanner.kt b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaClassScanner.kt index 0982f796..6eee7cf1 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaClassScanner.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaClassScanner.kt @@ -119,9 +119,10 @@ internal class SchemaClassScanner(initialDictionary: BiMap>, al // The dictionary doesn't need to know what classes are used with scalars. // In addition, scalars can have duplicate classes so that breaks the bi-map. // Input types can also be excluded from the dictionary, since it's only used for interfaces, unions, and enums. + // Union types can also be excluded, as their possible types are resolved recursively later val dictionary = try { Maps.unmodifiableBiMap(HashBiMap.create>().also { - dictionary.filter { it.value.typeClass != null && it.key !is InputObjectTypeDefinition }.mapValuesTo(it) { it.value.typeClass } + dictionary.filter { it.value.typeClass != null && it.key !is InputObjectTypeDefinition && it.key !is UnionTypeDefinition}.mapValuesTo(it) { it.value.typeClass } }) } catch (t: Throwable) { throw SchemaClassScannerError("Error creating bimap of type => class", t) @@ -173,8 +174,9 @@ internal class SchemaClassScanner(initialDictionary: BiMap>, al } private fun getAllObjectTypeMembersOfDiscoveredUnions(): List { + val unionTypeNames = dictionary.keys.filterIsInstance().map { union -> union.name }.toSet() return dictionary.keys.filterIsInstance().map { union -> - union.memberTypes.filterIsInstance().map { objectDefinitionsByName[it.name] ?: throw SchemaClassScannerError("No object type found with name '${it.name}' for union: $union") } + union.memberTypes.filterIsInstance().filter { !unionTypeNames.contains(it.name) }.map { objectDefinitionsByName[it.name] ?: throw SchemaClassScannerError("No object type found with name '${it.name}' for union: $union") } }.flatten().distinct() } diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParser.kt b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParser.kt index 7c33eb0f..283ccfe8 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParser.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParser.kt @@ -190,12 +190,27 @@ class SchemaParser internal constructor(private val dictionary: TypeClassDiction .description(getDocumentation(definition)) .typeResolver(TypeResolverProxy()) + getLeafUnionObjects(definition, types).forEach { builder.possibleType(it) } + return builder.build() + } + + private fun getLeafUnionObjects(definition: UnionTypeDefinition, types: List): List { + val name = definition.name + val leafObjects = mutableListOf() + definition.memberTypes.forEach { val typeName = (it as TypeName).name - builder.possibleType(types.find { it.name == typeName } ?: throw SchemaError("Expected object type '$typeName' for union type '$name', but found none!")) - } - return builder.build() + // Is this a nested union? If so, expand + val nestedUnion : UnionTypeDefinition? = unionDefinitions.find { otherDefinition -> typeName == otherDefinition.name } + + if (nestedUnion != null) { + leafObjects.addAll(getLeafUnionObjects(nestedUnion, types)) + } else { + leafObjects.add(types.find { it.name == typeName } ?: throw SchemaError("Expected object type '$typeName' for union type '$name', but found none!")) + } + } + return leafObjects } private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition : FieldDefinition): GraphQLFieldDefinition.Builder { diff --git a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy index 316c4779..6a0290f3 100644 --- a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy +++ b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy @@ -113,6 +113,31 @@ class EndToEndSpec extends Specification { data.allItems } + def "generated schema should handle nested union types"() { + when: + def data = Utils.assertNoGraphQlErrors(gql) { + ''' + { + nestedUnionItems { + ... on Item { + itemId: id + } + ... on OtherItem { + otherItemId: id + } + ... on ThirdItem { + thirdItemId: id + } + } + } + ''' + } + + then: + // TODO, flimsy. Need to test with two different numbers of items, because the mutation test may have run before this test. + data.nestedUnionItems == [[itemId: 0], [itemId: 1], [otherItemId: 0], [otherItemId: 1], [thirdItemId: 100]] || data.nestedUnionItems == [[itemId: 0], [itemId: 1], [itemId: 2], [otherItemId: 0], [otherItemId: 1], [thirdItemId: 100]] + } + def "generated schema should handle scalar types"() { when: def data = Utils.assertNoGraphQlErrors(gql) { diff --git a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt index 2b867b9b..87d25d68 100644 --- a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt +++ b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt @@ -13,6 +13,7 @@ fun createSchema() = SchemaParser.newParser() .resolvers(Query(), Mutation(), Subscription(), ItemResolver(), UnusedRootResolver(), UnusedResolver()) .scalars(CustomUUIDScalar) .dictionary("OtherItem", OtherItemWithWrongName::class.java) + .dictionary("ThirdItem", ThirdItem::class.java) .build() .makeExecutableSchema() @@ -27,6 +28,8 @@ type Query { items(itemsInput: ItemSearchInput!): [Item!] optionalItem(itemsInput: ItemSearchInput!): Item allItems: [AllItems!] + otherUnionItems: [OtherUnion!] + nestedUnionItems: [NestedUnion!] itemsByInterface: [ItemInterface!] itemByUUID(uuid: UUID!): Item itemsWithOptionalInput(itemsInput: ItemSearchInput): [Item!] @@ -115,6 +118,14 @@ interface ItemInterface { union AllItems = Item | OtherItem +type ThirdItem { + id: Int! +} + +union OtherUnion = Item | ThirdItem + +union NestedUnion = OtherUnion | OtherItem + type Tag { id: Int! name: String! @@ -132,11 +143,17 @@ val otherItems = mutableListOf( OtherItemWithWrongName(1, "otherItem2", Type.TYPE_2, UUID.fromString("38f685f1-b460-4a54-d17f-7fd69e8cf3f8")) ) +val thirdItems = mutableListOf( + ThirdItem(100) +) + class Query: GraphQLQueryResolver, ListListResolver() { fun isEmpty() = items.isEmpty() fun items(input: ItemSearchInput): List = items.filter { it.name == input.name } fun optionalItem(input: ItemSearchInput) = items(input).firstOrNull()?.let { Optional.of(it) } ?: Optional.empty() fun allItems(): List = items + otherItems + fun otherUnionItems(): List = items + thirdItems + fun nestedUnionItems(): List = items + otherItems + thirdItems fun itemsByInterface(): List = items + otherItems fun itemByUUID(uuid: UUID): Item? = items.find { it.uuid == uuid } fun itemsWithOptionalInput(input: ItemSearchInput?) = if(input == null) items else items(input) @@ -195,6 +212,7 @@ interface ItemInterface { enum class Type { TYPE_1, TYPE_2 } data class Item(val id: Int, override val name: String, override val type: Type, override val uuid:UUID, val tags: List) : ItemInterface data class OtherItemWithWrongName(val id: Int, override val name: String, override val type: Type, override val uuid:UUID) : ItemInterface +data class ThirdItem(val id: Int) data class Tag(val id: Int, val name: String) data class ItemSearchInput(val name: String) data class NewItemInput(val name: String, val type: Type)