From 90db3886379768cad23ee9ab11d38225ad4d9f40 Mon Sep 17 00:00:00 2001 From: iparadiso <111398937+iparadiso@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:03:28 -0800 Subject: [PATCH 1/2] support bean method-level @DgsDataLoader annotations --- .../DgsInputArgumentConfiguration.kt | 4 +- .../netflix/graphql/dgs/DgsDataLoader.java | 3 +- .../graphql/dgs/DgsCodeRegistryBuilder.kt | 4 +- .../graphql/dgs/DgsDataFetchingEnvironment.kt | 41 +++++- .../DefaultDgsFederationResolver.kt | 16 ++- .../dgs/internal/DgsDataLoaderProvider.kt | 38 ++++-- .../graphql/dgs/internal/DgsSchemaProvider.kt | 5 +- ...DataFetchingEnvironmentArgumentResolver.kt | 7 +- .../dgs/ExampleBatchLoaderFromBean.java | 67 ++++++++++ .../dgs/ExampleBatchLoaderFromBeanName.java | 67 ++++++++++ .../dgs/DataFetcherWithDirectivesTest.kt | 5 +- .../dgs/DefaultDgsFederationResolverTest.kt | 33 ++--- ...DgsDataFetchingEnvironmentIsArgumentSet.kt | 2 +- .../dgs/DgsDataFetchingEnvironmentTest.kt | 121 +++++++----------- .../graphql/dgs/DgsDataLoaderProviderTest.kt | 1 + .../graphql/dgs/DgsSchemaProviderTest.kt | 2 +- .../graphql/dgs/InterfaceDataFetchersTest.kt | 5 +- .../graphql/dgs/context/DgsContextTest.kt | 2 +- .../graphql/dgs/internal/InputArgumentTest.kt | 2 +- 19 files changed, 306 insertions(+), 119 deletions(-) create mode 100644 graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBean.java create mode 100644 graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBeanName.java diff --git a/graphql-dgs-spring-boot-oss-autoconfigure/src/main/kotlin/com/netflix/graphql/dgs/autoconfig/DgsInputArgumentConfiguration.kt b/graphql-dgs-spring-boot-oss-autoconfigure/src/main/kotlin/com/netflix/graphql/dgs/autoconfig/DgsInputArgumentConfiguration.kt index b0d55b0fa..b8fd74db9 100644 --- a/graphql-dgs-spring-boot-oss-autoconfigure/src/main/kotlin/com/netflix/graphql/dgs/autoconfig/DgsInputArgumentConfiguration.kt +++ b/graphql-dgs-spring-boot-oss-autoconfigure/src/main/kotlin/com/netflix/graphql/dgs/autoconfig/DgsInputArgumentConfiguration.kt @@ -24,6 +24,7 @@ import com.netflix.graphql.dgs.internal.method.DataFetchingEnvironmentArgumentRe import com.netflix.graphql.dgs.internal.method.FallbackEnvironmentArgumentResolver import com.netflix.graphql.dgs.internal.method.InputArgumentResolver import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean +import org.springframework.context.ApplicationContext import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration import org.springframework.core.Ordered @@ -36,7 +37,8 @@ open class DgsInputArgumentConfiguration { open fun inputArgumentResolver(inputObjectMapper: InputObjectMapper): ArgumentResolver = InputArgumentResolver(inputObjectMapper) @Bean - open fun dataFetchingEnvironmentArgumentResolver(): ArgumentResolver = DataFetchingEnvironmentArgumentResolver() + open fun dataFetchingEnvironmentArgumentResolver(context: ApplicationContext): ArgumentResolver = + DataFetchingEnvironmentArgumentResolver(context) @Bean open fun coroutineArgumentResolver(): ArgumentResolver = ContinuationArgumentResolver() diff --git a/graphql-dgs/src/main/java/com/netflix/graphql/dgs/DgsDataLoader.java b/graphql-dgs/src/main/java/com/netflix/graphql/dgs/DgsDataLoader.java index ee2b47993..0053a0de4 100644 --- a/graphql-dgs/src/main/java/com/netflix/graphql/dgs/DgsDataLoader.java +++ b/graphql-dgs/src/main/java/com/netflix/graphql/dgs/DgsDataLoader.java @@ -17,7 +17,6 @@ package com.netflix.graphql.dgs; import com.netflix.graphql.dgs.internal.utils.DataLoaderNameUtil; -import org.dataloader.registries.DispatchPredicate; import org.springframework.stereotype.Component; import java.lang.annotation.ElementType; @@ -31,7 +30,7 @@ * The class or field must implement one of the BatchLoader interfaces. * See https://netflix.github.io/dgs/data-loaders/ */ -@Target({ElementType.TYPE, ElementType.FIELD}) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) @Component @Inherited diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsCodeRegistryBuilder.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsCodeRegistryBuilder.kt index 9bab1700b..0e784343e 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsCodeRegistryBuilder.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsCodeRegistryBuilder.kt @@ -23,6 +23,7 @@ import graphql.schema.DataFetchingEnvironment import graphql.schema.FieldCoordinates import graphql.schema.GraphQLCodeRegistry import graphql.schema.GraphQLFieldDefinition +import org.springframework.context.ApplicationContext /** * Utility wrapper for [GraphQLCodeRegistry.Builder] which provides @@ -32,6 +33,7 @@ import graphql.schema.GraphQLFieldDefinition class DgsCodeRegistryBuilder( private val dataFetcherResultProcessors: List, private val graphQLCodeRegistry: GraphQLCodeRegistry.Builder, + private val ctx: ApplicationContext, ) { fun dataFetcher( coordinates: FieldCoordinates, @@ -67,7 +69,7 @@ class DgsCodeRegistryBuilder( if (dfe is DgsDataFetchingEnvironment) { dfe } else { - DgsDataFetchingEnvironment(dfe) + DgsDataFetchingEnvironment(dfe, ctx) } return processor.process(result, env) } diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironment.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironment.kt index 9229725f6..1a840441f 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironment.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironment.kt @@ -22,10 +22,13 @@ import com.netflix.graphql.dgs.exceptions.NoDataLoaderFoundException import com.netflix.graphql.dgs.internal.utils.DataLoaderNameUtil import graphql.schema.DataFetchingEnvironment import org.dataloader.DataLoader -import java.util.* +import org.springframework.context.ApplicationContext +import org.springframework.context.ConfigurableApplicationContext +import org.springframework.core.type.StandardMethodMetadata class DgsDataFetchingEnvironment( private val dfe: DataFetchingEnvironment, + private val ctx: ApplicationContext, ) : DataFetchingEnvironment by dfe { fun getDfe(): DataFetchingEnvironment = this.dfe @@ -45,14 +48,42 @@ class DgsDataFetchingEnvironment( DataLoaderNameUtil.getDataLoaderName(loaderClass, annotation) } else { val loaders = loaderClass.fields.filter { it.isAnnotationPresent(DgsDataLoader::class.java) } - if (loaders.size > 1) throw MultipleDataLoadersDefinedException(loaderClass) - val loaderField = loaders.firstOrNull() ?: throw NoDataLoaderFoundException(loaderClass) - val theAnnotation = loaderField.getAnnotation(DgsDataLoader::class.java) - theAnnotation.name + if (loaders.isEmpty()) { + // annotation is not on the class, but potentially on the Bean definition + tryGetDataLoaderFromBeanDefinition(loaderClass) + } else { + if (loaders.size > 1) throw MultipleDataLoadersDefinedException(loaderClass) + val loaderField = loaders.firstOrNull() ?: throw NoDataLoaderFoundException(loaderClass) + val theAnnotation = loaderField.getAnnotation(DgsDataLoader::class.java) + theAnnotation.name + } } + return getDataLoader(loaderName) ?: throw NoDataLoaderFoundException("DataLoader with name $loaderName not found") } + private fun tryGetDataLoaderFromBeanDefinition(loaderClass: Class<*>): String { + var name = loaderClass.simpleName + if (ctx is ConfigurableApplicationContext) { + val beansOfType = ctx.beanFactory.getBeansOfType(loaderClass) + if (beansOfType.isEmpty()) { + throw NoDataLoaderFoundException(loaderClass) + } + if (beansOfType.size > 1) { + throw MultipleDataLoadersDefinedException(loaderClass) + } + val beanName = beansOfType.keys.first() + val beanDefinition = ctx.beanFactory.getBeanDefinition(beanName) + if (beanDefinition.source is StandardMethodMetadata) { + val methodMetadata = beanDefinition.source as StandardMethodMetadata + val method = methodMetadata.introspectedMethod + val methodAnnotation = method.getAnnotation(DgsDataLoader::class.java) + name = DataLoaderNameUtil.getDataLoaderName(loaderClass, methodAnnotation) + } + } + return name + } + /** * Check if an argument is explicitly set using "argument.nested.property" or "argument->nested->property" syntax. * Note that this requires String splitting which is expensive for hot code paths. diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/federation/DefaultDgsFederationResolver.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/federation/DefaultDgsFederationResolver.kt index 67e9b0bff..894c4b1a9 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/federation/DefaultDgsFederationResolver.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/federation/DefaultDgsFederationResolver.kt @@ -40,6 +40,7 @@ import org.dataloader.Try import org.slf4j.Logger import org.slf4j.LoggerFactory import org.springframework.beans.factory.annotation.Autowired +import org.springframework.context.ApplicationContext import org.springframework.util.ReflectionUtils import reactor.core.publisher.Mono import java.lang.reflect.InvocationTargetException @@ -59,9 +60,11 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver { constructor( entityFetcherRegistry: EntityFetcherRegistry, dataFetcherExceptionHandler: Optional, + applicationContext: ApplicationContext, ) : this() { this.entityFetcherRegistry = entityFetcherRegistry - dgsExceptionHandler = dataFetcherExceptionHandler + this.dgsExceptionHandler = dataFetcherExceptionHandler + this.applicationContext = applicationContext } /** @@ -73,6 +76,9 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver { @Autowired lateinit var dgsExceptionHandler: Optional + @Autowired + lateinit var applicationContext: ApplicationContext + private val entitiesDataFetcher: DataFetcher = DataFetcher { env -> dgsEntityFetchers(env) } override fun entitiesFetcher(): DataFetcher = entitiesDataFetcher @@ -139,7 +145,12 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver { val result = if (parameterTypes.last().isAssignableFrom(DgsDataFetchingEnvironment::class.java)) { - ReflectionUtils.invokeMethod(method, target, coercedValues, DgsDataFetchingEnvironment(env)) + ReflectionUtils.invokeMethod( + method, + target, + coercedValues, + DgsDataFetchingEnvironment(env, applicationContext), + ) } else { ReflectionUtils.invokeMethod(method, target, coercedValues) } @@ -216,6 +227,7 @@ open class DefaultDgsFederationResolver() : DgsFederationResolver { val dfe = if (env is DgsDataFetchingEnvironment) env.getDfe() else env return DgsDataFetchingEnvironment( DataFetchingEnvironmentImpl.newDataFetchingEnvironment(dfe).executionStepInfo(executionStepInfoWithPath).build(), + applicationContext, ) } diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt index 9e928a129..541304640 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt @@ -37,6 +37,8 @@ import org.slf4j.LoggerFactory import org.springframework.aop.support.AopUtils import org.springframework.beans.factory.NoSuchBeanDefinitionException import org.springframework.context.ApplicationContext +import org.springframework.context.ConfigurableApplicationContext +import org.springframework.core.type.StandardMethodMetadata import org.springframework.util.ReflectionUtils import java.time.Duration import java.util.concurrent.Executors @@ -134,18 +136,36 @@ class DgsDataLoaderProvider( private fun addDataLoaderComponents() { val dataLoaders = applicationContext.getBeansWithAnnotation(DgsDataLoader::class.java) - dataLoaders.values.forEach { dgsComponent -> - val javaClass = AopUtils.getTargetClass(dgsComponent) + dataLoaders.forEach { (beanName, beanInstance) -> + val javaClass = AopUtils.getTargetClass(beanInstance) + + // check for class-level annotations val annotation = javaClass.getAnnotation(DgsDataLoader::class.java) - val predicateField = javaClass.declaredFields.find { it.isAnnotationPresent(DgsDispatchPredicate::class.java) } - if (predicateField != null) { - ReflectionUtils.makeAccessible(predicateField) - val dispatchPredicate = predicateField.get(dgsComponent) - if (dispatchPredicate is DispatchPredicate) { - addDataLoaders(dgsComponent, javaClass, annotation, dispatchPredicate) + if (annotation != null) { + val predicateField = + javaClass.declaredFields.find { it.isAnnotationPresent(DgsDispatchPredicate::class.java) } + if (predicateField != null) { + ReflectionUtils.makeAccessible(predicateField) + val dispatchPredicate = predicateField.get(beanInstance) + if (dispatchPredicate is DispatchPredicate) { + addDataLoaders(beanInstance, javaClass, annotation, dispatchPredicate) + } + } else { + addDataLoaders(beanInstance, javaClass, annotation, null) } } else { - addDataLoaders(dgsComponent, javaClass, annotation, null) + // Check for method-level bean annotations in configuration classes + if (applicationContext is ConfigurableApplicationContext) { + val beanDefinition = applicationContext.beanFactory.getBeanDefinition(beanName) + if (beanDefinition.source is StandardMethodMetadata) { + val methodMetadata = beanDefinition.source as StandardMethodMetadata + val method = methodMetadata.introspectedMethod + val methodAnnotation = method.getAnnotation(DgsDataLoader::class.java) + if (methodAnnotation != null) { + addDataLoaders(beanInstance, javaClass, methodAnnotation, null) + } + } + } } } } diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsSchemaProvider.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsSchemaProvider.kt index 39366444a..7f80f0802 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsSchemaProvider.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsSchemaProvider.kt @@ -183,7 +183,7 @@ class DgsSchemaProvider( val runtimeWiringBuilder = RuntimeWiring.newRuntimeWiring().codeRegistry(codeRegistryBuilder).fieldVisibility(fieldVisibility) - val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder) + val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder, applicationContext) dgsComponents .asSequence() @@ -217,6 +217,7 @@ class DgsSchemaProvider( DefaultDgsFederationResolver( entityFetcherRegistry, dataFetcherExceptionHandler, + applicationContext, ) } @@ -269,7 +270,7 @@ class DgsSchemaProvider( codeRegistryBuilder: GraphQLCodeRegistry.Builder, registry: TypeDefinitionRegistry, ) { - val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder) + val dgsCodeRegistryBuilder = DgsCodeRegistryBuilder(dataFetcherResultProcessors, codeRegistryBuilder, applicationContext) dgsComponent .annotatedMethods() diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/DataFetchingEnvironmentArgumentResolver.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/DataFetchingEnvironmentArgumentResolver.kt index ee73ba7b3..8a5fa9285 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/DataFetchingEnvironmentArgumentResolver.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/DataFetchingEnvironmentArgumentResolver.kt @@ -18,13 +18,16 @@ package com.netflix.graphql.dgs.internal.method import com.netflix.graphql.dgs.DgsDataFetchingEnvironment import graphql.schema.DataFetchingEnvironment +import org.springframework.context.ApplicationContext import org.springframework.core.MethodParameter /** * Resolves method arguments for parameters of type [DataFetchingEnvironment] * or [DgsDataFetchingEnvironment]. */ -class DataFetchingEnvironmentArgumentResolver : ArgumentResolver { +class DataFetchingEnvironmentArgumentResolver( + private val ctx: ApplicationContext, +) : ArgumentResolver { override fun supportsParameter(parameter: MethodParameter): Boolean = parameter.parameterType == DgsDataFetchingEnvironment::class.java || parameter.parameterType == DataFetchingEnvironment::class.java @@ -34,7 +37,7 @@ class DataFetchingEnvironmentArgumentResolver : ArgumentResolver { dfe: DataFetchingEnvironment, ): Any { if (parameter.parameterType == DgsDataFetchingEnvironment::class.java && dfe !is DgsDataFetchingEnvironment) { - return DgsDataFetchingEnvironment(dfe) + return DgsDataFetchingEnvironment(dfe, ctx) } return dfe } diff --git a/graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBean.java b/graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBean.java new file mode 100644 index 000000000..3ac15b9a7 --- /dev/null +++ b/graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBean.java @@ -0,0 +1,67 @@ +/* + * Copyright 2020 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.graphql.dgs; + +import org.dataloader.BatchLoader; +import org.dataloader.DataLoader; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; + +@Configuration +public class ExampleBatchLoaderFromBean { + + @DgsDataLoader + @Bean + public DataLoaderBeanClass batchLoaderBean() { + return new DataLoaderBeanClass(); + } + + @DgsComponent + @Bean + public HelloFetcherWithFromBean helloFetcherFromBean() { + return new HelloFetcherWithFromBean(); + } + + class DataLoaderBeanClass implements BatchLoader { + + @Override + public CompletionStage> load(List keys) { + List values = new ArrayList<>(); + values.add("a"); + values.add("b"); + values.add("c"); + return CompletableFuture.supplyAsync(() -> values); + } + } + + class HelloFetcherWithFromBean { + + @DgsData(parentType = "Query", field = "hello") + public CompletableFuture someFetcher(DgsDataFetchingEnvironment dfe) { + // validate data loader retrieval by type + DataLoader loader = dfe.getDataLoader(DataLoaderBeanClass.class); + loader.load("a"); + loader.load("b"); + return loader.load("c"); + } + } +} diff --git a/graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBeanName.java b/graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBeanName.java new file mode 100644 index 000000000..3d49099c7 --- /dev/null +++ b/graphql-dgs/src/test/java/com/netflix/graphql/dgs/ExampleBatchLoaderFromBeanName.java @@ -0,0 +1,67 @@ +/* + * Copyright 2020 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.graphql.dgs; + +import org.dataloader.BatchLoader; +import org.dataloader.DataLoader; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; + +@Configuration +public class ExampleBatchLoaderFromBeanName { + + @DgsDataLoader(name = "batchLoaderBeanFromBean") + @Bean + public DataLoaderBeanClass batchLoaderBean() { + return new DataLoaderBeanClass(); + } + + @DgsComponent + @Bean + public HelloFetcherWithFromBean helloFetcherFromBean() { + return new HelloFetcherWithFromBean(); + } + + class DataLoaderBeanClass implements BatchLoader { + + @Override + public CompletionStage> load(List keys) { + List values = new ArrayList<>(); + values.add("a"); + values.add("b"); + values.add("c"); + return CompletableFuture.supplyAsync(() -> values); + } + } + + class HelloFetcherWithFromBean { + + @DgsData(parentType = "Query", field = "hello") + public CompletableFuture someFetcher(DgsDataFetchingEnvironment dfe) { + // validate data loader retrieval by name + DataLoader loader = dfe.getDataLoader("batchLoaderBeanFromBean"); + loader.load("a"); + loader.load("b"); + return loader.load("c"); + } + } +} diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DataFetcherWithDirectivesTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DataFetcherWithDirectivesTest.kt index 13364bec0..37f71d3cf 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DataFetcherWithDirectivesTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DataFetcherWithDirectivesTest.kt @@ -73,7 +73,10 @@ class DataFetcherWithDirectivesTest { applicationContext = applicationContextMock, federationResolver = Optional.empty(), existingTypeDefinitionRegistry = Optional.empty(), - methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver())), + methodDataFetcherFactory = + MethodDataFetcherFactory( + listOf(DataFetchingEnvironmentArgumentResolver(applicationContextMock)), + ), ) val schema = diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DefaultDgsFederationResolverTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DefaultDgsFederationResolverTest.kt index 6458a28dd..31d70a975 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DefaultDgsFederationResolverTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DefaultDgsFederationResolverTest.kt @@ -106,7 +106,7 @@ class DefaultDgsFederationResolverTest { val graphQLSchema: GraphQLSchema = buildGraphQLSchema(schema) val type = - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .typeResolver() .getType( TypeResolutionParameters @@ -130,7 +130,7 @@ class DefaultDgsFederationResolverTest { val graphQLSchema: GraphQLSchema = buildGraphQLSchema(schema) val type = - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .typeResolver() .getType( TypeResolutionParameters @@ -159,7 +159,7 @@ class DefaultDgsFederationResolverTest { val graphQLSchema: GraphQLSchema = buildGraphQLSchema(schema) val customTypeResolver = - object : DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) { + object : DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) { override fun typeMapping(): Map, String> = mapOf(Movie::class.java to "DgsMovie") } @@ -180,7 +180,7 @@ class DefaultDgsFederationResolverTest { val dataFetchingEnvironment = constructDFE(arguments) val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -197,7 +197,7 @@ class DefaultDgsFederationResolverTest { val dataFetchingEnvironment = constructDFE(arguments) val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -291,7 +291,7 @@ class DefaultDgsFederationResolverTest { val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -351,7 +351,7 @@ class DefaultDgsFederationResolverTest { val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -422,7 +422,7 @@ class DefaultDgsFederationResolverTest { val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -449,7 +449,7 @@ class DefaultDgsFederationResolverTest { val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -488,7 +488,7 @@ class DefaultDgsFederationResolverTest { val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -572,7 +572,7 @@ class DefaultDgsFederationResolverTest { val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(customExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(customExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -624,7 +624,7 @@ class DefaultDgsFederationResolverTest { val dataFetchingEnvironment = constructDFE(arguments) val result = - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> @@ -670,7 +670,7 @@ class DefaultDgsFederationResolverTest { val dataFetchingEnvironment = constructDFE(arguments) val result = - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> @@ -758,7 +758,7 @@ class DefaultDgsFederationResolverTest { // Invoke the entitiesFetcher to get the result val result = - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(customExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(customExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> @@ -821,7 +821,7 @@ class DefaultDgsFederationResolverTest { val dataFetchingEnvironment = constructDFE(arguments) val result = - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.empty()) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.empty(), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> @@ -844,7 +844,7 @@ class DefaultDgsFederationResolverTest { val result = ( - DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler)) + DefaultDgsFederationResolver(entityFetcherRegistry, Optional.of(dgsExceptionHandler), applicationContextMock) .entitiesFetcher() .get(dataFetchingEnvironment) as CompletableFuture>> ) @@ -889,6 +889,7 @@ class DefaultDgsFederationResolverTest { .executionStepInfo(executionStepInfo) .mergedField(MergedField.newMergedField(Field("Movie")).build()) .build(), + applicationContextMock, ) } diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentIsArgumentSet.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentIsArgumentSet.kt index bd62bdcf0..c02f9e80b 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentIsArgumentSet.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentIsArgumentSet.kt @@ -43,7 +43,7 @@ class DgsDataFetchingEnvironmentIsArgumentSet { methodDataFetcherFactory = MethodDataFetcherFactory( listOf( - DataFetchingEnvironmentArgumentResolver(), + DataFetchingEnvironmentArgumentResolver(context), InputArgumentResolver( DefaultInputObjectMapper(), ), diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt index 71ba9e7c0..ec9032dce 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt @@ -26,7 +26,9 @@ import graphql.schema.DataFetchingEnvironment import org.dataloader.DataLoader import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test +import org.springframework.boot.autoconfigure.AutoConfigurations import org.springframework.boot.test.context.runner.ApplicationContextRunner +import org.springframework.context.ApplicationContext import java.util.Optional import java.util.concurrent.CompletableFuture import kotlin.reflect.KClass @@ -107,90 +109,63 @@ internal class DgsDataFetchingEnvironmentTest { @Test fun getDataLoader() { contextRunner.withBeans(ExampleBatchLoader::class, HelloFetcher::class).run { context -> - val provider = DgsDataLoaderProvider(context) - provider.findDataLoaders() - val dataLoaderRegistry = provider.buildRegistry() - - val schemaProvider = - DgsSchemaProvider( - applicationContext = context, - federationResolver = Optional.empty(), - existingTypeDefinitionRegistry = Optional.empty(), - methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver())), - ) - val schema = schemaProvider.schema().graphQLSchema - val build = GraphQL.newGraphQL(schema).build() - - val executionInput: ExecutionInput = - ExecutionInput - .newExecutionInput() - .query("{hello}") - .dataLoaderRegistry(dataLoaderRegistry) - .build() - val executionResult = build.execute(executionInput) - Assertions.assertTrue(executionResult.isDataPresent) - val result = executionResult.getData() as Map - Assertions.assertEquals("c", result["hello"]) + validateDataLoader(context) } } @Test fun getDataLoaderWithBasicDfe() { contextRunner.withBeans(HelloFetcherWithBasicDfe::class, ExampleBatchLoader::class).run { context -> - val provider = DgsDataLoaderProvider(context) - provider.findDataLoaders() - val dataLoaderRegistry = provider.buildRegistry() + validateDataLoader(context) + } + } - val schemaProvider = - DgsSchemaProvider( - applicationContext = context, - federationResolver = Optional.empty(), - existingTypeDefinitionRegistry = Optional.empty(), - methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver())), - ) - val schema = schemaProvider.schema().graphQLSchema - val build = GraphQL.newGraphQL(schema).build() + private fun validateDataLoader(context: ApplicationContext) { + val provider = DgsDataLoaderProvider(context) + provider.findDataLoaders() + val dataLoaderRegistry = provider.buildRegistry() + + val schemaProvider = + DgsSchemaProvider( + applicationContext = context, + federationResolver = Optional.empty(), + existingTypeDefinitionRegistry = Optional.empty(), + methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver(context))), + ) + val schema = schemaProvider.schema().graphQLSchema + val build = GraphQL.newGraphQL(schema).build() + + val executionInput: ExecutionInput = + ExecutionInput + .newExecutionInput() + .query("{hello}") + .dataLoaderRegistry(dataLoaderRegistry) + .build() + val executionResult = build.execute(executionInput) + Assertions.assertTrue(executionResult.isDataPresent) + val result = executionResult.getData() as Map + Assertions.assertEquals("c", result["hello"]) + } - val executionInput: ExecutionInput = - ExecutionInput - .newExecutionInput() - .query("{hello}") - .dataLoaderRegistry(dataLoaderRegistry) - .build() - val executionResult = build.execute(executionInput) - Assertions.assertTrue(executionResult.isDataPresent) - val result = executionResult.getData() as Map - Assertions.assertEquals("c", result["hello"]) - } + @Test + fun getDataLoaderFromBean() { + contextRunner + .withConfiguration(AutoConfigurations.of(ExampleBatchLoaderFromBean::class.java)) + .run { context -> + validateDataLoader(context) + } + + contextRunner + .withConfiguration(AutoConfigurations.of(ExampleBatchLoaderFromBeanName::class.java)) + .run { context -> + validateDataLoader(context) + } } @Test fun getDataLoaderFromField() { contextRunner.withBeans(HelloFetcherWithField::class, ExampleBatchLoaderFromField::class).run { context -> - val provider = DgsDataLoaderProvider(context) - provider.findDataLoaders() - val dataLoaderRegistry = provider.buildRegistry() - - val schemaProvider = - DgsSchemaProvider( - applicationContext = context, - federationResolver = Optional.empty(), - existingTypeDefinitionRegistry = Optional.empty(), - methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver())), - ) - val schema = schemaProvider.schema().graphQLSchema - val build = GraphQL.newGraphQL(schema).build() - - val executionInput: ExecutionInput = - ExecutionInput - .newExecutionInput() - .query("{hello}") - .dataLoaderRegistry(dataLoaderRegistry) - .build() - val executionResult = build.execute(executionInput) - Assertions.assertTrue(executionResult.isDataPresent) - val result = executionResult.getData() as Map - Assertions.assertEquals("c", result["hello"]) + validateDataLoader(context) } } @@ -236,7 +211,7 @@ internal class DgsDataFetchingEnvironmentTest { applicationContext = context, federationResolver = Optional.empty(), existingTypeDefinitionRegistry = Optional.empty(), - methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver())), + methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver(context))), ) val schema = schemaProvider.schema().graphQLSchema @@ -267,7 +242,7 @@ internal class DgsDataFetchingEnvironmentTest { applicationContext = context, federationResolver = Optional.empty(), existingTypeDefinitionRegistry = Optional.empty(), - methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver())), + methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver(context))), ) val schema = schemaProvider.schema().graphQLSchema diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataLoaderProviderTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataLoaderProviderTest.kt index d5294bc04..2cc71537f 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataLoaderProviderTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataLoaderProviderTest.kt @@ -209,6 +209,7 @@ class DgsDataLoaderProviderTest { .newDataFetchingEnvironment() .dataLoaderRegistry(dataLoaderRegistry) .build(), + context, ).getDataLoader(ExampleBatchLoaderWithoutName::class.java) Assertions.assertNotNull(dataLoader) } diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsSchemaProviderTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsSchemaProviderTest.kt index cfecdc0a7..b2f7e18f3 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsSchemaProviderTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsSchemaProviderTest.kt @@ -91,7 +91,7 @@ internal class DgsSchemaProviderTest { MethodDataFetcherFactory( listOf( InputArgumentResolver(DefaultInputObjectMapper()), - DataFetchingEnvironmentArgumentResolver(), + DataFetchingEnvironmentArgumentResolver(applicationContext), ), ), componentFilter = componentFilter, diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/InterfaceDataFetchersTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/InterfaceDataFetchersTest.kt index 86ee69471..65f1be654 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/InterfaceDataFetchersTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/InterfaceDataFetchersTest.kt @@ -99,7 +99,10 @@ class InterfaceDataFetchersTest { applicationContext = applicationContextMock, federationResolver = Optional.empty(), existingTypeDefinitionRegistry = Optional.empty(), - methodDataFetcherFactory = MethodDataFetcherFactory(listOf(DataFetchingEnvironmentArgumentResolver())), + methodDataFetcherFactory = + MethodDataFetcherFactory( + listOf(DataFetchingEnvironmentArgumentResolver(applicationContextMock)), + ), ) val schema = provider diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/context/DgsContextTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/context/DgsContextTest.kt index 32bc26a11..7a4b717e4 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/context/DgsContextTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/context/DgsContextTest.kt @@ -78,7 +78,7 @@ internal class DgsContextTest { MethodDataFetcherFactory( listOf( InputArgumentResolver(DefaultInputObjectMapper()), - DataFetchingEnvironmentArgumentResolver(), + DataFetchingEnvironmentArgumentResolver(context), ), ), ) diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/InputArgumentTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/InputArgumentTest.kt index d0265254a..aec8c0e16 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/InputArgumentTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/InputArgumentTest.kt @@ -88,7 +88,7 @@ internal class InputArgumentTest { MethodDataFetcherFactory( listOf( InputArgumentResolver(DefaultInputObjectMapper()), - DataFetchingEnvironmentArgumentResolver(), + DataFetchingEnvironmentArgumentResolver(applicationContext), FallbackEnvironmentArgumentResolver(DefaultInputObjectMapper()), ), ), From 5c1b83a4544cf4b809b9505cf2394573c4e6d3de Mon Sep 17 00:00:00 2001 From: iparadiso <111398937+iparadiso@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:22:38 -0800 Subject: [PATCH 2/2] lint fix --- .../com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt index ec9032dce..7f1565b77 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DgsDataFetchingEnvironmentTest.kt @@ -116,7 +116,7 @@ internal class DgsDataFetchingEnvironmentTest { @Test fun getDataLoaderWithBasicDfe() { contextRunner.withBeans(HelloFetcherWithBasicDfe::class, ExampleBatchLoader::class).run { context -> - validateDataLoader(context) + validateDataLoader(context) } }