Skip to content

Commit

Permalink
not nullable post againt endpoint (#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-olaveide authored Jan 19, 2024
1 parent 870dfd8 commit cb33892
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package no.nav.security.token.support.client.core.http
import no.nav.security.token.support.client.core.oauth2.OAuth2AccessTokenResponse

interface OAuth2HttpClient {
fun post(request : OAuth2HttpRequest) : OAuth2AccessTokenResponse?
fun post(req : OAuth2HttpRequest) : OAuth2AccessTokenResponse
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import no.nav.security.token.support.client.core.http.OAuth2HttpClient
import no.nav.security.token.support.client.core.http.OAuth2HttpHeaders
import no.nav.security.token.support.client.core.http.OAuth2HttpRequest

abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest?> internal constructor(private val oAuth2HttpClient : OAuth2HttpClient) {
abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest> internal constructor(private val oAuth2HttpClient : OAuth2HttpClient) {

protected abstract fun formParameters(grantRequest : T) : Map<String, String>

fun getTokenResponse(grantRequest : T) =
grantRequest?.clientProperties?.let {
grantRequest.clientProperties.let {
runCatching {
oAuth2HttpClient.post(OAuth2HttpRequest.builder(it.tokenEndpointUrl!!)
.oAuth2HttpHeaders(OAuth2HttpHeaders.of(tokenRequestHeaders(it)))
Expand Down Expand Up @@ -56,33 +56,31 @@ abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest?> intern

}

private fun defaultFormParameters(grantRequest : T) =
grantRequest?.clientProperties?.let {
private fun defaultFormParameters(grantRequest : T) : MutableMap<String, String> =
with(grantRequest.clientProperties) {
defaultClientAuthenticationFormParameters(grantRequest).apply {
put(GRANT_TYPE,grantRequest.grantType.value)
if (TOKEN_EXCHANGE != it.grantType) {
put(SCOPE, join(" ", it.scope))
if (TOKEN_EXCHANGE != grantType) {
put(SCOPE, join(" ", scope))
}
}
} ?: throw OAuth2ClientException("ClientProperties cannot be null")
}

private fun defaultClientAuthenticationFormParameters(grantRequest : T) =
grantRequest?.clientProperties?.let {
with(it) {
when (authentication.clientAuthMethod) {
CLIENT_SECRET_POST -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_SECRET, authentication.clientSecret!!)
}
PRIVATE_KEY_JWT -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_ASSERTION_TYPE, JWTAuthentication.CLIENT_ASSERTION_TYPE)
put(CLIENT_ASSERTION, ClientAssertion(tokenEndpointUrl!!, authentication).assertion())
}
else -> mutableMapOf()
with(grantRequest.clientProperties) {
when (authentication.clientAuthMethod) {
CLIENT_SECRET_POST -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_SECRET, authentication.clientSecret!!)
}
PRIVATE_KEY_JWT -> LinkedHashMap<String, String>().apply {
put(CLIENT_ID, authentication.clientId)
put(CLIENT_ASSERTION_TYPE, JWTAuthentication.CLIENT_ASSERTION_TYPE)
put(CLIENT_ASSERTION, ClientAssertion(tokenEndpointUrl!!, authentication).assertion())
}
else -> mutableMapOf()
}
} ?: throw OAuth2ClientException("ClientProperties cannot be null")
}

private fun basicAuth(username : String, password : String) =
UTF_8.newEncoder().run {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import no.nav.security.token.support.client.core.oauth2.OAuth2AccessTokenRespons

class SimpleOAuth2HttpClient : OAuth2HttpClient {

override fun post(request: OAuth2HttpRequest) =
override fun post(req: OAuth2HttpRequest) =
HttpRequest.newBuilder().apply {
configureRequest(request)
configureRequest(req)
}.build()
.sendRequest()
.processResponse()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ internal class OnBehalfOfTokenClientTest {
.contains("requested_token_use=on_behalf_of")
.contains("assertion=$assertion")
assertThat(response).isNotNull()
assertThat(response?.accessToken).isNotBlank()
assertThat(response?.expiresAt).isPositive()
assertThat(response?.expiresIn).isPositive()
assertThat(response.accessToken).isNotBlank()
assertThat(response.expiresAt).isPositive()
assertThat(response.expiresIn).isPositive()
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package no.nav.security.token.support.client.spring.oauth2
import org.springframework.http.HttpHeaders
import org.springframework.util.LinkedMultiValueMap
import org.springframework.web.client.RestClient
import org.springframework.web.client.body
import no.nav.security.token.support.client.core.OAuth2ClientException
import no.nav.security.token.support.client.core.http.OAuth2HttpClient
import no.nav.security.token.support.client.core.http.OAuth2HttpRequest
Expand All @@ -11,19 +12,19 @@ import no.nav.security.token.support.client.core.oauth2.OAuth2AccessTokenRespons
open class DefaultOAuth2HttpClient(val restClient: RestClient) : OAuth2HttpClient {


override fun post(request: OAuth2HttpRequest) =
override fun post(req: OAuth2HttpRequest) =
restClient.post()
.uri(request.tokenEndpointUrl)
.headers { it.addAll(headers(request)) }
.uri(req.tokenEndpointUrl)
.headers { it.addAll(headers(req)) }
.body(LinkedMultiValueMap<String, String>().apply {
setAll(request.formParameters)
setAll(req.formParameters)
}).retrieve()
.onStatus({ it.isError }) { _, response ->
throw OAuth2ClientException("Received $response.statusCode from $request.tokenEndpointUrl")
throw OAuth2ClientException("Received $response.statusCode from $req.tokenEndpointUrl")
}
.body(OAuth2AccessTokenResponse::class.java)
.body<OAuth2AccessTokenResponse>() ?: throw OAuth2ClientException("No body in response from $req.tokenEndpointUrl")

private fun headers(req: OAuth2HttpRequest): HttpHeaders = HttpHeaders().apply { req.oAuth2HttpHeaders?.let { putAll(it.headers) } }
private fun headers(req: OAuth2HttpRequest): HttpHeaders = HttpHeaders().apply { putAll(req.oAuth2HttpHeaders.headers) }

override fun toString() = "$javaClass.simpleName [restClient=$restClient]"
}

0 comments on commit cb33892

Please sign in to comment.