diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandler.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandler.kt index 23cfbdb8f..2f142d531 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandler.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandler.kt @@ -25,6 +25,7 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import org.slf4j.event.Level import org.springframework.util.ClassUtils +import java.lang.reflect.InvocationTargetException import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletionException @@ -50,6 +51,7 @@ open class DefaultDataFetcherExceptionHandler : DataFetcherExceptionHandler { isSpringSecurityAccessException( exception, ) -> TypedGraphQLError.newPermissionDeniedBuilder() + else -> TypedGraphQLError.newInternalErrorBuilder() } builder @@ -82,7 +84,12 @@ open class DefaultDataFetcherExceptionHandler : DataFetcherExceptionHandler { ) } - private fun unwrapCompletionException(e: Throwable): Throwable = if (e is CompletionException && e.cause != null) e.cause!! else e + private fun unwrapCompletionException(e: Throwable): Throwable = + when (e) { + is CompletionException -> unwrapCompletionException(e.cause ?: e) + is InvocationTargetException -> unwrapCompletionException(e.targetException) + else -> e + } protected val logger: Logger get() = DefaultDataFetcherExceptionHandler.logger diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandlerTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandlerTest.kt index 89edca0a7..4a97b30ab 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandlerTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/exceptions/DefaultDataFetcherExceptionHandlerTest.kt @@ -31,11 +31,12 @@ import io.mockk.spyk import io.mockk.verify import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertAll import org.slf4j.Logger import org.slf4j.event.Level import org.slf4j.spi.NOPLoggingEventBuilder import org.springframework.security.access.AccessDeniedException -import java.lang.IllegalStateException +import java.lang.reflect.InvocationTargetException import java.util.concurrent.CompletionException class DefaultDataFetcherExceptionHandlerTest { @@ -245,4 +246,23 @@ class DefaultDataFetcherExceptionHandlerTest { verify { loggerMock.atLevel(Level.ERROR) } confirmVerified(loggerMock) } + + @Test + fun `unwraps the invocation target exception`() { + val invocation = InvocationTargetException(IllegalStateException("I'm illegal!"), "Target invocation happened") + + val params = + DataFetcherExceptionHandlerParameters + .newExceptionParameters() + .exception(invocation) + .dataFetchingEnvironment(environment) + .build() + + val result = DefaultDataFetcherExceptionHandler().handleException(params).get() + + assertAll( + { assertThat(result.errors.size).isEqualTo(1) }, + { assertThat(result.errors[0].message).containsSubsequence("java.lang.IllegalStateException: I'm illegal!") }, + ) + } }