Skip to content

Commit

Permalink
Consistently handle proxied beans in DgsSchemaProvider
Browse files Browse the repository at this point in the history
Ensure that we are consistently handling proxied beans in DgsSchemaProvider by
refactoring some of the logic to retrieve annotated methods; after fetching
DgsComponent-annotated beans, use a new internal wrapper class, DgsBean, that
correctly handles proxies, and includes a helper method to retrieve all methods
decorated with a particular annotation. In addition to making the handling consistent,
moving the logic to the DgsBean wrapper class avoids repeated unwrapping of proxies
and allocations.
  • Loading branch information
Patrick Strawderman authored and kilink committed Jan 19, 2024
1 parent 835e63f commit 734ad67
Showing 1 changed file with 45 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class DgsSchemaProvider(
*/
fun isFieldTracingInstrumentationEnabled(field: String): Boolean {
return schemaReadWriteLock.read {
dataFetcherTracingInstrumentationEnabled.getOrDefault(field, true)
dataFetcherTracingInstrumentationEnabled[field] ?: true
}
}

Expand All @@ -137,7 +137,7 @@ class DgsSchemaProvider(
*/
fun isFieldMetricsInstrumentationEnabled(field: String): Boolean {
return schemaReadWriteLock.read {
dataFetcherMetricsInstrumentationEnabled.getOrDefault(field, true)
dataFetcherMetricsInstrumentationEnabled[field] ?: true
}
}

Expand All @@ -155,14 +155,13 @@ class DgsSchemaProvider(

private fun computeSchema(schema: String? = null, fieldVisibility: GraphqlFieldVisibility): GraphQLSchema {
val startTime = System.currentTimeMillis()
val dgsComponents = applicationContext.getBeansWithAnnotation<DgsComponent>().values.let { beans ->
if (componentFilter != null) beans.filter(componentFilter) else beans
}
val dgsComponents = applicationContext.getBeansWithAnnotation<DgsComponent>().values.asSequence()
.let { beans -> if (componentFilter != null) beans.filter(componentFilter) else beans }
.map { bean -> DgsBean(bean) }
.toList()

var mergedRegistry = if (schema == null) {
val hasDynamicTypeRegistry = dgsComponents.any {
it.javaClass.methods.any { m -> m.isAnnotationPresent(DgsTypeDefinitionRegistry::class.java) }
}
val hasDynamicTypeRegistry = dgsComponents.any { it.annotatedMethods<DgsTypeDefinitionRegistry>().any() }
val readerBuilder = MultiSourceReader.newMultiSourceReader()
.trackData(false)
for (schemaFile in findSchemaFiles(hasDynamicTypeRegistry)) {
Expand Down Expand Up @@ -239,30 +238,28 @@ class DgsSchemaProvider(
}

private fun invokeDgsTypeDefinitionRegistry(
dgsComponent: Any,
dgsComponent: DgsBean,
registry: TypeDefinitionRegistry
): TypeDefinitionRegistry? {
return dgsComponent.javaClass.methods.asSequence()
.filter { it.isAnnotationPresent(DgsTypeDefinitionRegistry::class.java) }
return dgsComponent.annotatedMethods<DgsTypeDefinitionRegistry>()
.map { method ->
if (method.returnType != TypeDefinitionRegistry::class.java) {
throw InvalidDgsConfigurationException("Method annotated with @DgsTypeDefinitionRegistry must have return type TypeDefinitionRegistry")
}
if (method.parameterCount == 1 && method.parameterTypes[0] == TypeDefinitionRegistry::class.java) {
ReflectionUtils.invokeMethod(method, dgsComponent, registry) as TypeDefinitionRegistry
ReflectionUtils.invokeMethod(method, dgsComponent.instance, registry) as TypeDefinitionRegistry
} else {
ReflectionUtils.invokeMethod(method, dgsComponent) as TypeDefinitionRegistry
ReflectionUtils.invokeMethod(method, dgsComponent.instance) as TypeDefinitionRegistry
}
}.reduceOrNull { a, b -> a.merge(b) }
}

private fun invokeDgsCodeRegistry(
dgsComponent: Any,
dgsComponent: DgsBean,
codeRegistryBuilder: GraphQLCodeRegistry.Builder,
registry: TypeDefinitionRegistry
) {
dgsComponent.javaClass.methods.asSequence()
.filter { it.isAnnotationPresent(DgsCodeRegistry::class.java) }
dgsComponent.annotatedMethods<DgsCodeRegistry>()
.forEach { method ->
if (method.returnType != GraphQLCodeRegistry.Builder::class.java) {
throw InvalidDgsConfigurationException("Method annotated with @DgsCodeRegistry must have return type GraphQLCodeRegistry.Builder")
Expand All @@ -272,13 +269,12 @@ class DgsSchemaProvider(
throw InvalidDgsConfigurationException("Method annotated with @DgsCodeRegistry must accept the following arguments: GraphQLCodeRegistry.Builder, TypeDefinitionRegistry. ${dgsComponent.javaClass.name}.${method.name} has the following arguments: ${method.parameterTypes.joinToString()}")
}

ReflectionUtils.invokeMethod(method, dgsComponent, codeRegistryBuilder, registry)
ReflectionUtils.invokeMethod(method, dgsComponent.instance, codeRegistryBuilder, registry)
}
}

private fun invokeDgsRuntimeWiring(dgsComponent: Any, runtimeWiringBuilder: RuntimeWiring.Builder) {
dgsComponent.javaClass.methods.asSequence()
.filter { it.isAnnotationPresent(DgsRuntimeWiring::class.java) }
private fun invokeDgsRuntimeWiring(dgsComponent: DgsBean, runtimeWiringBuilder: RuntimeWiring.Builder) {
dgsComponent.annotatedMethods<DgsRuntimeWiring>()
.forEach { method ->
if (method.returnType != RuntimeWiring.Builder::class.java) {
throw InvalidDgsConfigurationException("Method annotated with @DgsRuntimeWiring must have return type RuntimeWiring.Builder")
Expand All @@ -288,30 +284,29 @@ class DgsSchemaProvider(
throw InvalidDgsConfigurationException("Method annotated with @DgsRuntimeWiring must accept an argument of type RuntimeWiring.Builder. ${dgsComponent.javaClass.name}.${method.name} has the following arguments: ${method.parameterTypes.joinToString()}")
}

ReflectionUtils.invokeMethod(method, dgsComponent, runtimeWiringBuilder)
ReflectionUtils.invokeMethod(method, dgsComponent.instance, runtimeWiringBuilder)
}
}

private fun findDataFetchers(
dgsComponents: Collection<Any>,
dgsComponents: Collection<DgsBean>,
codeRegistryBuilder: GraphQLCodeRegistry.Builder,
typeDefinitionRegistry: TypeDefinitionRegistry
) {
dgsComponents.forEach { dgsComponent ->
val javaClass = AopUtils.getTargetClass(dgsComponent)
ReflectionUtils.getUniqueDeclaredMethods(javaClass, ReflectionUtils.USER_DECLARED_METHODS).asSequence()
dgsComponent.methods
.map { method ->
val mergedAnnotations = MergedAnnotations
.from(method, MergedAnnotations.SearchStrategy.TYPE_HIERARCHY)
Pair(method, mergedAnnotations)
method to mergedAnnotations
}
.filter { (_, mergedAnnotations) -> mergedAnnotations.isPresent(DgsData::class.java) }
.forEach { (method, mergedAnnotations) ->
val filteredMergedAnnotations =
mergedAnnotations
.stream(DgsData::class.java)
.filter { AopUtils.getTargetClass((it.source as Method).declaringClass) == AopUtils.getTargetClass(method.declaringClass) }
.collect(Collectors.toList())
.toList()
filteredMergedAnnotations.forEach { dgsDataAnnotation ->
registerDataFetcher(
typeDefinitionRegistry,
Expand All @@ -329,7 +324,7 @@ class DgsSchemaProvider(
private fun registerDataFetcher(
typeDefinitionRegistry: TypeDefinitionRegistry,
codeRegistryBuilder: GraphQLCodeRegistry.Builder,
dgsComponent: Any,
dgsComponent: DgsBean,
method: Method,
dgsDataAnnotation: MergedAnnotation<DgsData>,
mergedAnnotations: MergedAnnotations
Expand All @@ -342,7 +337,7 @@ class DgsSchemaProvider(
throw InvalidDgsConfigurationException("Duplicate data fetchers registered for $parentType.$field")
}

dataFetchers.add(DataFetcherReference(dgsComponent, method, mergedAnnotations, parentType, field))
dataFetchers += DataFetcherReference(dgsComponent.instance, method, mergedAnnotations, parentType, field)

val enableTracingInstrumentation = if (method.isAnnotationPresent(DgsEnableDataFetcherInstrumentation::class.java)) {
val dgsEnableDataFetcherInstrumentation =
Expand Down Expand Up @@ -382,7 +377,7 @@ class DgsSchemaProvider(
// register the base implementation for interfaces
if (!codeRegistryBuilder.hasDataFetcher(FieldCoordinates.coordinates(implType.name, field))) {
val dataFetcher =
createBasicDataFetcher(method, dgsComponent, parentType == "Subscription")
createBasicDataFetcher(method, dgsComponent.instance, parentType == "Subscription")
codeRegistryBuilder.dataFetcher(
FieldCoordinates.coordinates(implType.name, field),
dataFetcher
Expand All @@ -397,7 +392,7 @@ class DgsSchemaProvider(
is UnionTypeDefinition -> {
type.memberTypes.asSequence().filterIsInstance<TypeName>().forEach { memberType ->
val dataFetcher =
createBasicDataFetcher(method, dgsComponent, parentType == "Subscription")
createBasicDataFetcher(method, dgsComponent.instance, parentType == "Subscription")
codeRegistryBuilder.dataFetcher(
FieldCoordinates.coordinates(memberType.name, field),
dataFetcher
Expand All @@ -412,10 +407,10 @@ class DgsSchemaProvider(
getMatchingFieldOnObjectOrExtensions(methodClassName, type, field, typeDefinitionRegistry, parentType)
checkInputArgumentsAreValid(
method,
matchingField.inputValueDefinitions.map { it.name }.toSet()
matchingField.inputValueDefinitions.asSequence().map { it.name }.toSet()
)
}
val dataFetcher = createBasicDataFetcher(method, dgsComponent, parentType == "Subscription")
val dataFetcher = createBasicDataFetcher(method, dgsComponent.instance, parentType == "Subscription")
codeRegistryBuilder.dataFetcher(
FieldCoordinates.coordinates(parentType, field),
dataFetcher
Expand Down Expand Up @@ -500,12 +495,9 @@ class DgsSchemaProvider(
}
}

private fun findEntityFetchers(dgsComponents: Collection<Any>, registry: TypeDefinitionRegistry, runtimeWiring: RuntimeWiring) {
private fun findEntityFetchers(dgsComponents: Collection<DgsBean>, registry: TypeDefinitionRegistry, runtimeWiring: RuntimeWiring) {
dgsComponents.forEach { dgsComponent ->
val javaClass = AopUtils.getTargetClass(dgsComponent)

ReflectionUtils.getDeclaredMethods(javaClass).asSequence()
.filter { it.isAnnotationPresent(DgsEntityFetcher::class.java) }
dgsComponent.annotatedMethods<DgsEntityFetcher>()
.forEach { method ->
val dgsEntityFetcherAnnotation = method.getAnnotation(DgsEntityFetcher::class.java)

Expand All @@ -517,7 +509,7 @@ class DgsSchemaProvider(
dataFetcherMetricsInstrumentationEnabled["${"__entities"}.${dgsEntityFetcherAnnotation.name}"] =
enableInstrumentation

entityFetcherRegistry.entityFetchers[dgsEntityFetcherAnnotation.name] = dgsComponent to method
entityFetcherRegistry.entityFetchers[dgsEntityFetcherAnnotation.name] = dgsComponent.instance to method

val type = registry.getType(dgsEntityFetcherAnnotation.name)

Expand Down Expand Up @@ -593,16 +585,14 @@ class DgsSchemaProvider(
}

private fun findTypeResolvers(
dgsComponents: Collection<Any>,
dgsComponents: Collection<DgsBean>,
runtimeWiringBuilder: RuntimeWiring.Builder,
mergedRegistry: TypeDefinitionRegistry
) {
val registeredTypeResolvers = mutableSetOf<String>()

dgsComponents.forEach { dgsComponent ->
val javaClass = AopUtils.getTargetClass(dgsComponent)
javaClass.methods.asSequence()
.filter { it.isAnnotationPresent(DgsTypeResolver::class.java) }
dgsComponent.annotatedMethods<DgsTypeResolver>()
.forEach { method ->
val annotation = method.getAnnotation(DgsTypeResolver::class.java)

Expand All @@ -622,22 +612,21 @@ class DgsSchemaProvider(
val defaultTypeResolver = method.getAnnotation(DgsDefaultTypeResolver::class.java)
if (defaultTypeResolver != null) {
overrideTypeResolver = dgsComponents.any { component ->
component.javaClass.methods.any { method ->
method.isAnnotationPresent(DgsTypeResolver::class.java) &&
method.getAnnotation(DgsTypeResolver::class.java).name == annotation.name &&
component != dgsComponent
component != dgsComponent && component.annotatedMethods<DgsTypeResolver>().any { method ->
method.getAnnotation(DgsTypeResolver::class.java).name == annotation.name
}
}
}
// do not add the default resolver if another resolver with the same name is present
if (defaultTypeResolver == null || !overrideTypeResolver) {
registeredTypeResolvers.add(annotation.name)
registeredTypeResolvers += annotation.name

val dgsComponentInstance = dgsComponent.instance
runtimeWiringBuilder.type(
TypeRuntimeWiring.newTypeWiring(annotation.name)
.typeResolver { env: TypeResolutionEnvironment ->
val typeName: String? =
ReflectionUtils.invokeMethod(method, dgsComponent, env.getObject()) as String?
ReflectionUtils.invokeMethod(method, dgsComponentInstance, env.getObject()) as String?
if (typeName != null) {
env.schema.getObjectType(typeName)
} else {
Expand Down Expand Up @@ -700,7 +689,7 @@ class DgsSchemaProvider(

private fun findDirectives(applicationContext: ApplicationContext, runtimeWiringBuilder: RuntimeWiring.Builder) {
applicationContext.getBeansWithAnnotation<DgsDirective>().values.forEach { directiveComponent ->
val annotation = directiveComponent::class.java.getAnnotation(DgsDirective::class.java)
val annotation = AopUtils.getTargetClass(directiveComponent).getAnnotation(DgsDirective::class.java)
when (directiveComponent) {
is SchemaDirectiveWiring ->
if (annotation.name.isNotBlank()) {
Expand Down Expand Up @@ -745,5 +734,12 @@ class DgsSchemaProvider(
companion object {
const val DEFAULT_SCHEMA_LOCATION = "classpath*:schema/**/*.graphql*"
private val logger: Logger = LoggerFactory.getLogger(DgsSchemaProvider::class.java)
private data class DgsBean(val instance: Any, val targetClass: Class<*> = AopUtils.getTargetClass(instance)) {
private val cachedMethods = ReflectionUtils.getUniqueDeclaredMethods(targetClass, ReflectionUtils.USER_DECLARED_METHODS)
val methods: Sequence<Method> get() = cachedMethods.asSequence()

inline fun <reified T : Annotation> annotatedMethods(): Sequence<Method> =
methods.filter { it.isAnnotationPresent(T::class.java) }
}
}
}

0 comments on commit 734ad67

Please sign in to comment.