Skip to content

Commit

Permalink
Merge branch 'master' into fix/#2024
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbakker authored Nov 6, 2024
2 parents 9784ce6 + b4513c6 commit 06dac26
Show file tree
Hide file tree
Showing 19 changed files with 307 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.netflix.graphql.dgs.internal.method.DataFetchingEnvironmentArgumentRe
import com.netflix.graphql.dgs.internal.method.FallbackEnvironmentArgumentResolver
import com.netflix.graphql.dgs.internal.method.InputArgumentResolver
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean
import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.core.Ordered
Expand All @@ -36,7 +37,8 @@ open class DgsInputArgumentConfiguration {
open fun inputArgumentResolver(inputObjectMapper: InputObjectMapper): ArgumentResolver = InputArgumentResolver(inputObjectMapper)

@Bean
open fun dataFetchingEnvironmentArgumentResolver(): ArgumentResolver = DataFetchingEnvironmentArgumentResolver()
open fun dataFetchingEnvironmentArgumentResolver(context: ApplicationContext): ArgumentResolver =
DataFetchingEnvironmentArgumentResolver(context)

@Bean
open fun coroutineArgumentResolver(): ArgumentResolver = ContinuationArgumentResolver()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.netflix.graphql.dgs;

import com.netflix.graphql.dgs.internal.utils.DataLoaderNameUtil;
import org.dataloader.registries.DispatchPredicate;
import org.springframework.stereotype.Component;

import java.lang.annotation.ElementType;
Expand All @@ -31,7 +30,7 @@
* The class or field must implement one of the BatchLoader interfaces.
* See https://netflix.github.io/dgs/data-loaders/
*/
@Target({ElementType.TYPE, ElementType.FIELD})
@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Component
@Inherited
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import graphql.schema.DataFetchingEnvironment
import graphql.schema.FieldCoordinates
import graphql.schema.GraphQLCodeRegistry
import graphql.schema.GraphQLFieldDefinition
import org.springframework.context.ApplicationContext

/**
* Utility wrapper for [GraphQLCodeRegistry.Builder] which provides
Expand All @@ -32,6 +33,7 @@ import graphql.schema.GraphQLFieldDefinition
class DgsCodeRegistryBuilder(
private val dataFetcherResultProcessors: List<DataFetcherResultProcessor>,
private val graphQLCodeRegistry: GraphQLCodeRegistry.Builder,
private val ctx: ApplicationContext,
) {
fun dataFetcher(
coordinates: FieldCoordinates,
Expand Down Expand Up @@ -67,7 +69,7 @@ class DgsCodeRegistryBuilder(
if (dfe is DgsDataFetchingEnvironment) {
dfe
} else {
DgsDataFetchingEnvironment(dfe)
DgsDataFetchingEnvironment(dfe, ctx)
}
return processor.process(result, env)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ import com.netflix.graphql.dgs.exceptions.NoDataLoaderFoundException
import com.netflix.graphql.dgs.internal.utils.DataLoaderNameUtil
import graphql.schema.DataFetchingEnvironment
import org.dataloader.DataLoader
import java.util.*
import org.springframework.context.ApplicationContext
import org.springframework.context.ConfigurableApplicationContext
import org.springframework.core.type.StandardMethodMetadata

class DgsDataFetchingEnvironment(
private val dfe: DataFetchingEnvironment,
private val ctx: ApplicationContext,
) : DataFetchingEnvironment by dfe {
fun getDfe(): DataFetchingEnvironment = this.dfe

Expand All @@ -45,14 +48,42 @@ class DgsDataFetchingEnvironment(
DataLoaderNameUtil.getDataLoaderName(loaderClass, annotation)
} else {
val loaders = loaderClass.fields.filter { it.isAnnotationPresent(DgsDataLoader::class.java) }
if (loaders.size > 1) throw MultipleDataLoadersDefinedException(loaderClass)
val loaderField = loaders.firstOrNull() ?: throw NoDataLoaderFoundException(loaderClass)
val theAnnotation = loaderField.getAnnotation(DgsDataLoader::class.java)
theAnnotation.name
if (loaders.isEmpty()) {
// annotation is not on the class, but potentially on the Bean definition
tryGetDataLoaderFromBeanDefinition(loaderClass)
} else {
if (loaders.size > 1) throw MultipleDataLoadersDefinedException(loaderClass)
val loaderField = loaders.firstOrNull() ?: throw NoDataLoaderFoundException(loaderClass)
val theAnnotation = loaderField.getAnnotation(DgsDataLoader::class.java)
theAnnotation.name
}
}

return getDataLoader(loaderName) ?: throw NoDataLoaderFoundException("DataLoader with name $loaderName not found")
}

private fun tryGetDataLoaderFromBeanDefinition(loaderClass: Class<*>): String {
var name = loaderClass.simpleName
if (ctx is ConfigurableApplicationContext) {
val beansOfType = ctx.beanFactory.getBeansOfType(loaderClass)
if (beansOfType.isEmpty()) {
throw NoDataLoaderFoundException(loaderClass)
}
if (beansOfType.size > 1) {
throw MultipleDataLoadersDefinedException(loaderClass)
}
val beanName = beansOfType.keys.first()
val beanDefinition = ctx.beanFactory.getBeanDefinition(beanName)
if (beanDefinition.source is StandardMethodMetadata) {
val methodMetadata = beanDefinition.source as StandardMethodMetadata
val method = methodMetadata.introspectedMethod
val methodAnnotation = method.getAnnotation(DgsDataLoader::class.java)
name = DataLoaderNameUtil.getDataLoaderName(loaderClass, methodAnnotation)
}
}
return name
}

/**
* Check if an argument is explicitly set using "argument.nested.property" or "argument->nested->property" syntax.
* Note that this requires String splitting which is expensive for hot code paths.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import org.dataloader.Try
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext
import org.springframework.util.ReflectionUtils
import reactor.core.publisher.Mono
import java.lang.reflect.InvocationTargetException
Expand All @@ -59,9 +60,11 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver {
constructor(
entityFetcherRegistry: EntityFetcherRegistry,
dataFetcherExceptionHandler: Optional<DataFetcherExceptionHandler>,
applicationContext: ApplicationContext,
) : this() {
this.entityFetcherRegistry = entityFetcherRegistry
dgsExceptionHandler = dataFetcherExceptionHandler
this.dgsExceptionHandler = dataFetcherExceptionHandler
this.applicationContext = applicationContext
}

/**
Expand All @@ -73,6 +76,9 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver {
@Autowired
lateinit var dgsExceptionHandler: Optional<DataFetcherExceptionHandler>

@Autowired
lateinit var applicationContext: ApplicationContext

private val entitiesDataFetcher: DataFetcher<Any?> = DataFetcher { env -> dgsEntityFetchers(env) }

override fun entitiesFetcher(): DataFetcher<Any?> = entitiesDataFetcher
Expand Down Expand Up @@ -139,7 +145,12 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver {

val result =
if (parameterTypes.last().isAssignableFrom(DgsDataFetchingEnvironment::class.java)) {
ReflectionUtils.invokeMethod(method, target, coercedValues, DgsDataFetchingEnvironment(env))
ReflectionUtils.invokeMethod(
method,
target,
coercedValues,
DgsDataFetchingEnvironment(env, applicationContext),
)
} else {
ReflectionUtils.invokeMethod(method, target, coercedValues)
}
Expand Down Expand Up @@ -216,6 +227,7 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver {
val dfe = if (env is DgsDataFetchingEnvironment) env.getDfe() else env
return DgsDataFetchingEnvironment(
DataFetchingEnvironmentImpl.newDataFetchingEnvironment(dfe).executionStepInfo(executionStepInfoWithPath).build(),
applicationContext,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import org.slf4j.LoggerFactory
import org.springframework.aop.support.AopUtils
import org.springframework.beans.factory.NoSuchBeanDefinitionException
import org.springframework.context.ApplicationContext
import org.springframework.context.ConfigurableApplicationContext
import org.springframework.core.type.StandardMethodMetadata
import org.springframework.util.ReflectionUtils
import java.time.Duration
import java.util.concurrent.Executors
Expand Down Expand Up @@ -128,19 +130,37 @@ class DgsDataLoaderProvider(

private fun addDataLoaderComponents() {
val dataLoaders = applicationContext.getBeansWithAnnotation(DgsDataLoader::class.java)
dataLoaders.values.forEach { dgsComponent ->
val javaClass = AopUtils.getTargetClass(dgsComponent)
dataLoaders.forEach { (beanName, beanInstance) ->
val javaClass = AopUtils.getTargetClass(beanInstance)

// check for class-level annotations
val annotation = javaClass.getAnnotation(DgsDataLoader::class.java)
val dataLoaderName = DataLoaderNameUtil.getDataLoaderName(javaClass, annotation)
val predicateField = javaClass.declaredFields.find { it.isAnnotationPresent(DgsDispatchPredicate::class.java) }
if (predicateField != null) {
ReflectionUtils.makeAccessible(predicateField)
val dispatchPredicate = predicateField.get(dgsComponent)
if (dispatchPredicate is DispatchPredicate) {
addDataLoader(dgsComponent, dataLoaderName, javaClass, annotation, dispatchPredicate)
if (annotation != null) {
val dataLoaderName = DataLoaderNameUtil.getDataLoaderName(javaClass, annotation)
val predicateField = javaClass.declaredFields.find { it.isAnnotationPresent(DgsDispatchPredicate::class.java) }
if (predicateField != null) {
ReflectionUtils.makeAccessible(predicateField)
val dispatchPredicate = predicateField.get(beanInstance)
if (dispatchPredicate is DispatchPredicate) {
addDataLoader(beanInstance, dataLoaderName, javaClass, annotation, dispatchPredicate)
}
} else {
addDataLoader(beanInstance, dataLoaderName, javaClass, annotation)
}
} else {
addDataLoader(dgsComponent, dataLoaderName, javaClass, annotation)
// Check for method-level bean annotations in configuration classes
if (applicationContext is ConfigurableApplicationContext) {
val beanDefinition = applicationContext.beanFactory.getBeanDefinition(beanName)
if (beanDefinition.source is StandardMethodMetadata) {
val methodMetadata = beanDefinition.source as StandardMethodMetadata
val method = methodMetadata.introspectedMethod
val methodAnnotation = method.getAnnotation(DgsDataLoader::class.java)
if (methodAnnotation != null) {
val dataLoaderName = DataLoaderNameUtil.getDataLoaderName(javaClass, methodAnnotation)
addDataLoader(beanInstance, dataLoaderName, javaClass, methodAnnotation, null)
}
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class DgsSchemaProvider(
val runtimeWiringBuilder =
RuntimeWiring.newRuntimeWiring().codeRegistry(codeRegistryBuilder).fieldVisibility(fieldVisibility)

val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder)
val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder, applicationContext)

dgsComponents
.asSequence()
Expand Down Expand Up @@ -218,6 +218,7 @@ class DgsSchemaProvider(
DefaultDgsFederationResolver(
entityFetcherRegistry,
dataFetcherExceptionHandler,
applicationContext,
)
}

Expand Down Expand Up @@ -270,7 +271,7 @@ class DgsSchemaProvider(
codeRegistryBuilder: GraphQLCodeRegistry.Builder,
registry: TypeDefinitionRegistry,
) {
val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder)
val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder, applicationContext)

dgsComponent
.annotatedMethods<DgsCodeRegistry>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ package com.netflix.graphql.dgs.internal.method

import com.netflix.graphql.dgs.DgsDataFetchingEnvironment
import graphql.schema.DataFetchingEnvironment
import org.springframework.context.ApplicationContext
import org.springframework.core.MethodParameter

/**
* Resolves method arguments for parameters of type [DataFetchingEnvironment]
* or [DgsDataFetchingEnvironment].
*/
class DataFetchingEnvironmentArgumentResolver : ArgumentResolver {
class DataFetchingEnvironmentArgumentResolver(
private val ctx: ApplicationContext,
) : ArgumentResolver {
override fun supportsParameter(parameter: MethodParameter): Boolean =
parameter.parameterType == DgsDataFetchingEnvironment::class.java ||
parameter.parameterType == DataFetchingEnvironment::class.java
Expand All @@ -34,7 +37,7 @@ class DataFetchingEnvironmentArgumentResolver : ArgumentResolver {
dfe: DataFetchingEnvironment,
): Any {
if (parameter.parameterType == DgsDataFetchingEnvironment::class.java && dfe !is DgsDataFetchingEnvironment) {
return DgsDataFetchingEnvironment(dfe)
return DgsDataFetchingEnvironment(dfe, ctx)
}
return dfe
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2020 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.netflix.graphql.dgs;

import org.dataloader.BatchLoader;
import org.dataloader.DataLoader;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

@Configuration
public class ExampleBatchLoaderFromBean {

@DgsDataLoader
@Bean
public DataLoaderBeanClass batchLoaderBean() {
return new DataLoaderBeanClass();
}

@DgsComponent
@Bean
public HelloFetcherWithFromBean helloFetcherFromBean() {
return new HelloFetcherWithFromBean();
}

class DataLoaderBeanClass implements BatchLoader<String, String> {

@Override
public CompletionStage<List<String>> load(List<String> keys) {
List<String> values = new ArrayList<>();
values.add("a");
values.add("b");
values.add("c");
return CompletableFuture.supplyAsync(() -> values);
}
}

class HelloFetcherWithFromBean {

@DgsData(parentType = "Query", field = "hello")
public CompletableFuture<String> someFetcher(DgsDataFetchingEnvironment dfe) {
// validate data loader retrieval by type
DataLoader<String, String> loader = dfe.getDataLoader(DataLoaderBeanClass.class);
loader.load("a");
loader.load("b");
return loader.load("c");
}
}
}
Loading

0 comments on commit 06dac26

Please sign in to comment.