diff --git a/graphql-dgs-example-shared/src/main/java/com/netflix/graphql/dgs/example/shared/datafetcher/HelloDataFetcher.java b/graphql-dgs-example-shared/src/main/java/com/netflix/graphql/dgs/example/shared/datafetcher/HelloDataFetcher.java index e093b809b..7d1a0c604 100644 --- a/graphql-dgs-example-shared/src/main/java/com/netflix/graphql/dgs/example/shared/datafetcher/HelloDataFetcher.java +++ b/graphql-dgs-example-shared/src/main/java/com/netflix/graphql/dgs/example/shared/datafetcher/HelloDataFetcher.java @@ -106,6 +106,14 @@ public CompletableFuture withDataLoaderGraphQLContext(DataFetchingEnviro return exampleLoaderWithContext.load(CONTRIBUTOR_ENABLED_CONTEXT_KEY); } + @DgsData(parentType = "Query", field = "withDataLoaderGraphQLContextWithFromDfe") + @DgsEnableDataFetcherInstrumentation + public CompletableFuture withDataLoaderGraphQLContextWithFromDfe(DataFetchingEnvironment dfe) { + dfe.getGraphQlContext().put(CONTRIBUTOR_ENABLED_CONTEXT_KEY, "override"); + DataLoader exampleLoaderWithContext = dfe.getDataLoader("exampleLoaderWithGraphQLContext"); + return exampleLoaderWithContext.load(CONTRIBUTOR_ENABLED_CONTEXT_KEY); + } + @DgsData(parentType = "Query", field = "withGraphqlException") public String withGraphqlException() { throw new GraphQLException("that's not going to work!"); diff --git a/graphql-dgs-example-shared/src/main/resources/schema/schema.graphqls b/graphql-dgs-example-shared/src/main/resources/schema/schema.graphqls index feeb75021..8fd0d3bc9 100644 --- a/graphql-dgs-example-shared/src/main/resources/schema/schema.graphqls +++ b/graphql-dgs-example-shared/src/main/resources/schema/schema.graphqls @@ -4,6 +4,7 @@ type Query { withContext: String withDataLoaderContext: String withDataLoaderGraphQLContext: String + withDataLoaderGraphQLContextWithFromDfe: String movies: [Movie] messageFromBatchLoader: String messageFromBatchLoaderWithGreetings: String diff --git a/graphql-dgs-spring-graphql-example-java/src/test/java/com/netflix/graphql/dgs/example/datafetcher/GraphQLContextContributorTest.java b/graphql-dgs-spring-graphql-example-java/src/test/java/com/netflix/graphql/dgs/example/datafetcher/GraphQLContextContributorTest.java index b0768210e..80e7816eb 100644 --- a/graphql-dgs-spring-graphql-example-java/src/test/java/com/netflix/graphql/dgs/example/datafetcher/GraphQLContextContributorTest.java +++ b/graphql-dgs-spring-graphql-example-java/src/test/java/com/netflix/graphql/dgs/example/datafetcher/GraphQLContextContributorTest.java @@ -61,4 +61,13 @@ void withDataloaderGraphQLContext() { String contributorEnabled = queryExecutor.executeAndExtractJsonPath("{ withDataLoaderGraphQLContext }", "data.withDataLoaderGraphQLContext", servletWebRequest); assertThat(contributorEnabled).isEqualTo("true"); } + + @Test + void withDataloaderGraphQLContextOverride() { + final MockHttpServletRequest mockServletRequest = new MockHttpServletRequest(); + mockServletRequest.addHeader(CONTEXT_CONTRIBUTOR_HEADER_NAME, CONTEXT_CONTRIBUTOR_HEADER_VALUE); + ServletWebRequest servletWebRequest = new ServletWebRequest(mockServletRequest); + String contributorEnabled = queryExecutor.executeAndExtractJsonPath("{ withDataLoaderGraphQLContextWithFromDfe }", "data.withDataLoaderGraphQLContextWithFromDfe", servletWebRequest); + assertThat(contributorEnabled).isEqualTo("override"); + } } \ No newline at end of file diff --git a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/SpringGraphQLDgsQueryExecutor.kt b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/SpringGraphQLDgsQueryExecutor.kt index 9f4050baf..40667d215 100644 --- a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/SpringGraphQLDgsQueryExecutor.kt +++ b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/SpringGraphQLDgsQueryExecutor.kt @@ -30,7 +30,6 @@ import com.netflix.graphql.dgs.internal.DgsDataLoaderProvider import com.netflix.graphql.dgs.internal.DgsQueryExecutorRequestCustomizer import com.netflix.graphql.dgs.internal.DgsWebMvcRequestData import graphql.ExecutionResult -import graphql.GraphQLContext import org.springframework.graphql.ExecutionGraphQlService import org.springframework.graphql.support.DefaultExecutionGraphQlRequest import org.springframework.http.HttpHeaders @@ -65,20 +64,9 @@ class SpringGraphQLDgsQueryExecutor( val httpRequest = requestCustomizer.apply(webRequest ?: RequestContextHolder.getRequestAttributes() as? WebRequest, headers) val dgsContext = dgsContextBuilder.build(DgsWebMvcRequestData(request.extensions, headers, httpRequest)) - val dataLoaderRegistry = - dgsDataLoaderProvider.buildRegistryWithContextSupplier { - val graphQLContext = request.toExecutionInput().graphQLContext - if (graphQLContextContributors.isNotEmpty()) { - val requestData = dgsContext.requestData - val builderForContributors = GraphQLContext.newContext() - graphQLContextContributors.forEach { it.contribute(builderForContributors, extensions, requestData) } - graphQLContext.putAll(builderForContributors) - } - - graphQLContext - } - - request.configureExecutionInput { _, builder -> + + request.configureExecutionInput { e, builder -> + val dataLoaderRegistry = dgsDataLoaderProvider.buildRegistryWithContextSupplier { e.graphQLContext } builder .context(dgsContext) .graphQLContext(dgsContext) diff --git a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webflux/DgsWebFluxGraphQLInterceptor.kt b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webflux/DgsWebFluxGraphQLInterceptor.kt index c40239baa..86f125800 100644 --- a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webflux/DgsWebFluxGraphQLInterceptor.kt +++ b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webflux/DgsWebFluxGraphQLInterceptor.kt @@ -19,14 +19,13 @@ package com.netflix.graphql.dgs.springgraphql.webflux import com.netflix.graphql.dgs.internal.DgsDataLoaderProvider import com.netflix.graphql.dgs.reactive.internal.DefaultDgsReactiveGraphQLContextBuilder import com.netflix.graphql.dgs.reactive.internal.DgsReactiveRequestData -import graphql.GraphQLContext +import org.dataloader.DataLoaderRegistry import org.springframework.graphql.server.WebGraphQlInterceptor import org.springframework.graphql.server.WebGraphQlRequest import org.springframework.graphql.server.WebGraphQlResponse import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter import org.springframework.web.reactive.function.server.ServerRequest import reactor.core.publisher.Mono -import java.util.concurrent.CompletableFuture class DgsWebFluxGraphQLInterceptor( private val dgsDataLoaderProvider: DgsDataLoaderProvider, @@ -48,21 +47,19 @@ class DgsWebFluxGraphQLInterceptor( ), ) }.flatMap { dgsContext -> - val graphQLContextFuture = CompletableFuture() - val dataLoaderRegistry = - dgsDataLoaderProvider.buildRegistryWithContextSupplier { graphQLContextFuture.get() } - - request.configureExecutionInput { _, builder -> + var dataLoaderRegistry: DataLoaderRegistry? = null + request.configureExecutionInput { e, builder -> + dataLoaderRegistry = dgsDataLoaderProvider.buildRegistryWithContextSupplier { e.graphQLContext } builder .context(dgsContext) .graphQLContext(dgsContext) .dataLoaderRegistry(dataLoaderRegistry) .build() } - graphQLContextFuture.complete(request.toExecutionInput().graphQLContext) + chain.next(request).doFinally { if (dataLoaderRegistry is AutoCloseable) { - dataLoaderRegistry.close() + (dataLoaderRegistry as AutoCloseable).close() } } } diff --git a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webmvc/DgsWebMvcGraphQLInterceptor.kt b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webmvc/DgsWebMvcGraphQLInterceptor.kt index ca0092b9a..6f9794823 100644 --- a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webmvc/DgsWebMvcGraphQLInterceptor.kt +++ b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/webmvc/DgsWebMvcGraphQLInterceptor.kt @@ -21,7 +21,7 @@ import com.netflix.graphql.dgs.internal.DefaultDgsGraphQLContextBuilder import com.netflix.graphql.dgs.internal.DgsDataLoaderProvider import com.netflix.graphql.dgs.internal.DgsWebMvcRequestData import com.netflix.graphql.dgs.springgraphql.autoconfig.DgsSpringGraphQLConfigurationProperties -import graphql.GraphQLContext +import org.dataloader.DataLoaderRegistry import org.springframework.graphql.server.WebGraphQlInterceptor import org.springframework.graphql.server.WebGraphQlRequest import org.springframework.graphql.server.WebGraphQlResponse @@ -56,21 +56,13 @@ class DgsWebMvcGraphQLInterceptor( } else { dgsContextBuilder.build(DgsWebMvcRequestData(request.extensions, request.headers)) } - val dataLoaderRegistry = - dgsDataLoaderProvider.buildRegistryWithContextSupplier { - val graphQLContext = request.toExecutionInput().graphQLContext - if (graphQLContextContributors.isNotEmpty()) { - val extensions = request.extensions - val requestData = dgsContext.requestData - val builderForContributors = GraphQLContext.newContext() - graphQLContextContributors.forEach { it.contribute(builderForContributors, extensions, requestData) } - graphQLContext.putAll(builderForContributors) - } - graphQLContext - } + var dataLoaderRegistry: DataLoaderRegistry? = null + request.configureExecutionInput { e, builder -> + + dataLoaderRegistry = + dgsDataLoaderProvider.buildRegistryWithContextSupplier { e.graphQLContext } - request.configureExecutionInput { _, builder -> builder .context(dgsContext) .graphQLContext(dgsContext) @@ -81,14 +73,14 @@ class DgsWebMvcGraphQLInterceptor( return if (dgsSpringConfigurationProperties.webmvc.asyncdispatch.enabled) { chain.next(request).doFinally { if (dataLoaderRegistry is AutoCloseable) { - dataLoaderRegistry.close() + (dataLoaderRegistry as AutoCloseable).close() } } } else { @Suppress("BlockingMethodInNonBlockingContext") val response = chain.next(request).block()!! if (dataLoaderRegistry is AutoCloseable) { - dataLoaderRegistry.close() + (dataLoaderRegistry as AutoCloseable).close() } return Mono.just(response) }