Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix registration of dev provider in Service.authMiddleware.Providers #201

Merged
merged 3 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,39 +234,44 @@ func (s *Service) AddProviderWithUserAttributes(name, cid, csecret string, userA
L: s.logger,
UserAttributes: userAttributes,
}
s.addProvider(name, p)
s.addProviderByName(name, p)
}

func (s *Service) addProvider(name string, p provider.Params) {
func (s *Service) addProviderByName(name string, p provider.Params) {
var prov provider.Provider
switch strings.ToLower(name) {
case "github":
s.providers = append(s.providers, provider.NewService(provider.NewGithub(p)))
prov = provider.NewGithub(p)
case "google":
s.providers = append(s.providers, provider.NewService(provider.NewGoogle(p)))
prov = provider.NewGoogle(p)
case "facebook":
s.providers = append(s.providers, provider.NewService(provider.NewFacebook(p)))
prov = provider.NewFacebook(p)
case "yandex":
s.providers = append(s.providers, provider.NewService(provider.NewYandex(p)))
prov = provider.NewYandex(p)
case "battlenet":
s.providers = append(s.providers, provider.NewService(provider.NewBattlenet(p)))
prov = provider.NewBattlenet(p)
case "microsoft":
s.providers = append(s.providers, provider.NewService(provider.NewMicrosoft(p)))
prov = provider.NewMicrosoft(p)
case "twitter":
s.providers = append(s.providers, provider.NewService(provider.NewTwitter(p)))
prov = provider.NewTwitter(p)
case "patreon":
s.providers = append(s.providers, provider.NewService(provider.NewPatreon(p)))
prov = provider.NewPatreon(p)
case "dev":
s.providers = append(s.providers, provider.NewService(provider.NewDev(p)))
prov = provider.NewDev(p)
default:
return
}

s.addProvider(prov)
}

func (s *Service) addProvider(prov provider.Provider) {
s.providers = append(s.providers, provider.NewService(prov))
s.authMiddleware.Providers = s.providers
}

// AddProvider adds provider for given name
func (s *Service) AddProvider(name, cid, csecret string) {

p := provider.Params{
URL: s.opts.URL,
JwtService: s.jwtService,
Expand All @@ -277,8 +282,7 @@ func (s *Service) AddProvider(name, cid, csecret string) {
L: s.logger,
UserAttributes: map[string]string{},
}

s.addProvider(name, p)
s.addProviderByName(name, p)
}

// AddDevProvider with a custom host and port
Expand All @@ -292,7 +296,7 @@ func (s *Service) AddDevProvider(host string, port int) {
Port: port,
Host: host,
}
s.providers = append(s.providers, provider.NewService(provider.NewDev(p)))
s.addProvider(provider.NewDev(p))
}

// AddAppleProvider allow SignIn with Apple ID
Expand All @@ -311,7 +315,7 @@ func (s *Service) AddAppleProvider(appleConfig provider.AppleConfig, privKeyLoad
return fmt.Errorf("an AppleProvider creating failed: %w", err)
}

s.providers = append(s.providers, provider.NewService(appleProvider))
s.addProvider(appleProvider)
return nil
}

Expand All @@ -326,9 +330,7 @@ func (s *Service) AddCustomProvider(name string, client Client, copts provider.C
Csecret: client.Csecret,
L: s.logger,
}

s.providers = append(s.providers, provider.NewService(provider.NewCustom(name, p, copts)))
s.authMiddleware.Providers = s.providers
s.addProvider(provider.NewCustom(name, p, copts))
}

// AddDirectProvider adds provider with direct check against data store
Expand All @@ -342,8 +344,7 @@ func (s *Service) AddDirectProvider(name string, credChecker provider.CredChecke
CredChecker: credChecker,
AvatarSaver: s.avatarProxy,
}
s.providers = append(s.providers, provider.NewService(dh))
s.authMiddleware.Providers = s.providers
s.addProvider(dh)
}

// AddDirectProviderWithUserIDFunc adds provider with direct check against data store and sets custom UserIDFunc allows
Expand All @@ -359,8 +360,7 @@ func (s *Service) AddDirectProviderWithUserIDFunc(name string, credChecker provi
AvatarSaver: s.avatarProxy,
UserIDFunc: ufn,
}
s.providers = append(s.providers, provider.NewService(dh))
s.authMiddleware.Providers = s.providers
s.addProvider(dh)
}

// AddVerifProvider adds provider user's verification sent by sender
Expand All @@ -375,14 +375,12 @@ func (s *Service) AddVerifProvider(name, msgTmpl string, sender provider.Sender)
Template: msgTmpl,
UseGravatar: s.useGravatar,
}
s.providers = append(s.providers, provider.NewService(dh))
s.authMiddleware.Providers = s.providers
s.addProvider(dh)
}

// AddCustomHandler adds user-defined self-implemented handler of auth provider
func (s *Service) AddCustomHandler(handler provider.Provider) {
s.providers = append(s.providers, provider.NewService(handler))
s.authMiddleware.Providers = s.providers
func (s *Service) AddCustomHandler(p provider.Provider) {
s.addProvider(p)
}

// DevAuth makes dev oauth2 server, for testing and development only!
Expand Down
78 changes: 46 additions & 32 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ func TestProvider(t *testing.T) {
_, err := svc.Provider("some provider")
assert.EqualError(t, err, "provider some provider not found")

svc.AddProvider("dev", "cid", "csecret")
svc.AddProviderWithUserAttributes("dev", "cid", "csecret", provider.UserAttributes{"attrName": "attrValue"})
svc.AddProvider("github", "cid", "csecret")
svc.AddProvider("google", "cid", "csecret")
svc.AddProvider("facebook", "cid", "csecret")
svc.AddProvider("yandex", "cid", "csecret")
svc.AddProvider("microsoft", "cid", "csecret")
svc.AddProvider("twitter", "cid", "csecret")
svc.AddProvider("battlenet", "cid", "csecret")
svc.AddProvider("patreon", "cid", "csecret")
svc.AddProvider("bad", "cid", "csecret")
Expand All @@ -72,14 +73,15 @@ func TestProvider(t *testing.T) {
assert.Equal(t, "cid", op.Cid)
assert.Equal(t, "csecret", op.Csecret)
assert.Equal(t, "go-pkgz/auth", op.Issuer)
assert.Equal(t, provider.UserAttributes{"attrName": "attrValue"}, op.Params.UserAttributes)

p, err = svc.Provider("github")
assert.NoError(t, err)
op = p.Provider.(provider.Oauth2Handler)
assert.Equal(t, "github", op.Name())

pp := svc.Providers()
assert.Equal(t, 9, len(pp))
assert.Equal(t, 10, len(pp))

ch, err := svc.Provider("telegramBotMySiteCom")
assert.NoError(t, err)
Expand Down Expand Up @@ -227,7 +229,11 @@ func TestIntegrationAvatar(t *testing.T) {
}

func TestIntegrationList(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddProvider("github", "cid", "csec")
// add go-oauth2/oauth2 provider
svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{})
})
defer teardown()

resp, err := http.Get("http://127.0.0.1:8089/auth/list")
Expand All @@ -237,7 +243,7 @@ func TestIntegrationList(t *testing.T) {

b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, `["dev","github","custom123","direct","direct_custom","email"]`+"\n", string(b))
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

func TestIntegrationUserInfo(t *testing.T) {
Expand Down Expand Up @@ -336,7 +342,11 @@ func TestBadRequests(t *testing.T) {
}

func TestDirectProvider(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}))
})
defer teardown()

// login
Expand Down Expand Up @@ -374,19 +384,28 @@ func TestDirectProvider(t *testing.T) {
}

func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddDirectProviderWithUserIDFunc("directCustom",
provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}),
func(user string, r *http.Request) string {
return "blah"
},
)
})
defer teardown()

// login
jar, err := cookiejar.New(nil)
require.Nil(t, err)
client := &http.Client{Jar: jar, Timeout: 5 * time.Second}
resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad")
resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)

resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password")
resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
Expand All @@ -396,7 +415,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
t.Logf("resp %s", string(body))
t.Logf("headers: %+v", resp.Header)

assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)
assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)

require.Equal(t, 2, len(resp.Cookies()))
assert.Equal(t, "JWT", resp.Cookies()[0].Name)
Expand All @@ -412,7 +431,9 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
}

func TestVerifProvider(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddVerifProvider("email", "{{.Token}}", &sender)
})
defer teardown()

// login
Expand Down Expand Up @@ -488,7 +509,16 @@ func TestStatus(t *testing.T) {

}

func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unparam
func TestDevAuthServerWithoutDevProvider(t *testing.T) {
svc := NewService(Opts{})
assert.NotNil(t, svc)

_, err := svc.DevAuth()
require.NotNil(t, err)
assert.EqualError(t, err, "dev provider not registered: provider dev not found")
}

func prepService(t *testing.T, providerConfigFunctions ...func(svc *Service)) (svc *Service, teardown func()) { //nolint unparam

options := Opts{
SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
Expand All @@ -509,28 +539,12 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara
}

svc = NewService(options)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084
svc.AddProvider("github", "cid", "csec") // add github provider

// add go-oauth2/oauth2 provider
svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{})

// add direct provider
svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}))

// add direct provider with custom user id func
svc.AddDirectProviderWithUserIDFunc("direct_custom",
provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}),
func(user string, r *http.Request) string {
return "blah"
},
)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084

svc.AddVerifProvider("email", "{{.Token}}", &sender)
for _, f := range providerConfigFunctions {
f(svc)
}

// run dev/test oauth2 server on :18084
devAuth, err := svc.DevAuth()
Expand All @@ -546,7 +560,7 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara
_, _ = w.Write([]byte("open route, no token needed\n"))
}))
mux.Handle("/private", m.Auth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("open route, no token needed\n"))
_, _ = w.Write([]byte("protected route, authenticated with token\n"))
})))

// setup auth routes
Expand Down
Loading
Loading