diff --git a/swagger.go b/swagger.go index 92c246b..80eb048 100644 --- a/swagger.go +++ b/swagger.go @@ -30,6 +30,22 @@ type Config struct { Layout SwaggerLayout DefaultModelsExpandDepth ModelsExpandDepthType ShowExtensions bool + + // The information for OAuth2 integration, if any. + OAuth *OAuthConfig +} + +// OAuthConfig stores configuration for Swagger UI OAuth2 integration. See +// https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/ for further details. +type OAuthConfig struct { + // The ID of the client sent to the OAuth2 IAM provider. + ClientId string + + // The OAuth2 realm that the client should operate in. If not applicable, use empty string. + Realm string + + // The name to display for the application in the authentication popup. + AppName string } // URL presents the url pointing to API definition (normally swagger.json or swagger.yaml). @@ -113,6 +129,12 @@ func AfterScript(js string) func(*Config) { } } +func OAuth(config *OAuthConfig) func(*Config) { + return func(c *Config) { + c.OAuth = config + } +} + type SwaggerLayout string const ( @@ -176,7 +198,6 @@ func newConfig(configFns ...func(*Config)) *Config { // Handler wraps `http.Handler` into `http.HandlerFunc`. func Handler(configFns ...func(*Config)) http.HandlerFunc { - config := newConfig(configFns...) // create a template with name @@ -334,6 +355,13 @@ window.onload = function() { defaultModelsExpandDepth: {{.DefaultModelsExpandDepth}}, showExtensions: {{.ShowExtensions}} }) + {{if .OAuth}} + ui.initOAuth({ + clientId: "{{.OAuth.ClientId}}", + realm: "{{.OAuth.Realm}}", + appName: "{{.OAuth.AppName}}" + }) + {{end}} window.ui = ui {{- if .AfterScript}} diff --git a/swagger_test.go b/swagger_test.go index 8ef6168..1774130 100644 --- a/swagger_test.go +++ b/swagger_test.go @@ -40,7 +40,6 @@ func (s *mockedSwag) ReadDoc() string { } func TestWrapHandler(t *testing.T) { - tests := []struct { RootFolder string InstanceName string @@ -110,6 +109,24 @@ func TestWrapHandler(t *testing.T) { } } +func TestConfigWithOAuth(t *testing.T) { + router := http.NewServeMux() + router.Handle("/", Handler(OAuth(&OAuthConfig{ + ClientId: "my-client-id", + Realm: "my-realm", + AppName: "My App Name", + }))) + + w := performRequest(http.MethodGet, "/index.html", router) + assert.Equal(t, 200, w.Code) + body := w.Body.String() + assert.Contains(t, body, `ui.initOAuth({ + clientId: "my-client-id", + realm: "my-realm", + appName: "My App Name" + })`) +} + func performRequest(method, target string, h http.Handler) *httptest.ResponseRecorder { r := httptest.NewRequest(method, target, nil) w := httptest.NewRecorder() @@ -199,7 +216,6 @@ func TestPersistAuthorization(t *testing.T) { } func TestConfigURL(t *testing.T) { - type fixture struct { desc string cfgfn func(c *Config) @@ -293,7 +309,6 @@ func TestConfigURL(t *testing.T) { } func TestUIConfigOptions(t *testing.T) { - type fixture struct { desc string cfg *Config @@ -390,7 +405,7 @@ func TestUIConfigOptions(t *testing.T) { DefaultModelsExpandDepth: ShowModel, }, exp: `window.onload = function() { - + const ui = SwaggerUIBundle({ url: "doc.json", deepLinking: true , @@ -406,9 +421,11 @@ func TestUIConfigOptions(t *testing.T) { SwaggerUIBundle.plugins.DownloadUrl ], layout: "StandaloneLayout", - defaultModelsExpandDepth: 1 + defaultModelsExpandDepth: 1 , + showExtensions: false }) + window.ui = ui }`, }, @@ -445,7 +462,7 @@ func TestUIConfigOptions(t *testing.T) { // Some plugin }); - + const ui = SwaggerUIBundle({ url: "swagger.json", deepLinking: false , @@ -466,9 +483,11 @@ func TestUIConfigOptions(t *testing.T) { onComplete: () => { window.ui.setBasePath('v3'); }, showExtensions: true, layout: "StandaloneLayout", - defaultModelsExpandDepth: -1 + defaultModelsExpandDepth: -1 , + showExtensions: true }) + window.ui = ui const someOtherCode = function(){ // Do something @@ -572,3 +591,23 @@ func TestShowExtensions(t *testing.T) { cfg = newConfig(ShowExtensions(false)) assert.False(t, cfg.ShowExtensions) } + +func TestOAuth(t *testing.T) { + var cfg Config + expected := OAuthConfig{ + ClientId: "my-client-id", + Realm: "my-realm", + AppName: "My App Name", + } + OAuth(&expected)(&cfg) + assert.Equal(t, expected.ClientId, cfg.OAuth.ClientId) + assert.Equal(t, expected.Realm, cfg.OAuth.Realm) + assert.Equal(t, expected.AppName, cfg.OAuth.AppName) +} + +func TestOAuthNil(t *testing.T) { + var cfg Config + var expected *OAuthConfig + OAuth(expected)(&cfg) + assert.Equal(t, expected, cfg.OAuth) +}