Skip to content

Commit

Permalink
user/manager: connector must exists when creating remote identity
Browse files Browse the repository at this point in the history
Add ConnectorConfigRepo to UserManager. When trying to create a
RemoteIdentity, validate that the connector ID exists.

Fixes #198
  • Loading branch information
ericchiang committed Dec 8, 2015
1 parent d518447 commit f43655a
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 14 deletions.
3 changes: 2 additions & 1 deletion cmd/dex-overlord/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ func main() {

userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc)
connCfgRepo := db.NewConnectorConfigRepo(dbc)
userManager := manager.NewUserManager(userRepo,
pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
pwiRepo, connCfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
kRepo, err := db.NewPrivateKeySetRepo(dbc, *useOldFormat, keySecrets.BytesSlice()...)
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions connector/config_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"io"
"os"

"github.com/coreos/dex/repo"
)

func newConnectorConfigsFromReader(r io.Reader) ([]ConnectorConfig, error) {
Expand Down Expand Up @@ -41,6 +43,19 @@ type memConnectorConfigRepo struct {
configs []ConnectorConfig
}

func NewConnectorConfigRepoFromConfigs(cfgs []ConnectorConfig) ConnectorConfigRepo {
return &memConnectorConfigRepo{configs: cfgs}
}

func (r *memConnectorConfigRepo) All() ([]ConnectorConfig, error) {
return r.configs, nil
}

func (r *memConnectorConfigRepo) GetConnectorByID(_ repo.Transaction, id string) (ConnectorConfig, error) {
for _, cfg := range r.configs {
if cfg.ConnectorID() == id {
return cfg, nil
}
}
return nil, ErrorNotFound
}
5 changes: 5 additions & 0 deletions connector/interface.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package connector

import (
"errors"
"html/template"
"net/http"
"net/url"

"github.com/coreos/dex/repo"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health"
)

var ErrorNotFound = errors.New("connector not found in repository")

type Connector interface {
ID() string
LoginURL(sessionKey, prompt string) (string, error)
Expand All @@ -34,4 +38,5 @@ type ConnectorConfig interface {

type ConnectorConfigRepo interface {
All() ([]ConnectorConfig, error)
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
}
26 changes: 26 additions & 0 deletions db/connector_config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand All @@ -9,6 +10,7 @@ import (
"github.com/lib/pq"

"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
)

const (
Expand Down Expand Up @@ -91,6 +93,18 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) {
return cfgs, nil
}

func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) {
qt := pq.QuoteIdentifier(connectorConfigTableName)
q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt)
var c connectorConfigModel
if err := r.executor(tx).SelectOne(&c, q, id); err != nil {
if err == sql.ErrNoRows {
return nil, connector.ErrorNotFound
}
}
return c.ConnectorConfig()
}

func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {
insert := make([]interface{}, len(cfgs))
for i, cfg := range cfgs {
Expand Down Expand Up @@ -119,3 +133,15 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error {

return tx.Commit()
}

func (r *ConnectorConfigRepo) executor(tx repo.Transaction) gorp.SqlExecutor {
if tx == nil {
return r.dbMap
}

gorpTx, ok := tx.(*gorp.Transaction)
if !ok {
panic("wrong kind of transaction passed to a DB repo")
}
return gorpTx
}
71 changes: 71 additions & 0 deletions functional/repo/connector_repo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package repo

import (
"fmt"
"os"
"testing"

"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
)

type connectorConfigRepoFactory func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo

var makeTestConnectorConfigRepoFromConfigs connectorConfigRepoFactory

func init() {
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" {
makeTestConnectorConfigRepoFromConfigs = connector.NewConnectorConfigRepoFromConfigs
} else {
makeTestConnectorConfigRepoFromConfigs = makeTestConnectorConfigRepoMem(dsn)
}
}

func makeTestConnectorConfigRepoMem(dsn string) connectorConfigRepoFactory {
return func(cfgs []connector.ConnectorConfig) connector.ConnectorConfigRepo {
dbMap := initDB(dsn)

repo := db.NewConnectorConfigRepo(dbMap)
if err := repo.Set(cfgs); err != nil {
panic(fmt.Sprintf("Unable to set connector configs: %v", err))
}
return repo
}
}

func TestConnectorConfigRepoGetByID(t *testing.T) {
tests := []struct {
cfgs []connector.ConnectorConfig
id string
err error
}{
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
},
id: "local",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "local2",
},
{
cfgs: []connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local1"},
&connector.LocalConnectorConfig{ID: "local2"},
},
id: "foo",
err: connector.ErrorNotFound,
},
}

for i, tt := range tests {
repo := makeTestConnectorConfigRepoFromConfigs(tt.cfgs)
if _, err := repo.GetConnectorByID(nil, tt.id); err != tt.err {
t.Errorf("case %d: want=%v, got=%v", i, tt.err, err)
}
}
}
6 changes: 5 additions & 1 deletion integration/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/coreos/go-oidc/key"
"github.com/jonboulle/clockwork"

"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
Expand Down Expand Up @@ -47,7 +48,10 @@ func makeUserObjects(users []user.UserWithRemoteIdentities, passwords []user.Pas
ur := user.NewUserRepoFromUsers(users)
pwr := user.NewPasswordInfoRepoFromPasswordInfos(passwords)

um := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{})
ccr := connector.NewConnectorConfigRepoFromConfigs(
[]connector.ConnectorConfig{&connector.LocalConnectorConfig{ID: "local"}},
)
um := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
um.Clock = clock
return ur, pwr, um
}
4 changes: 2 additions & 2 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
refTokRepo := refresh.NewRefreshTokenRepo()

txnFactory := repo.InMemTransactionFactory
userManager := manager.NewUserManager(userRepo, pwiRepo, txnFactory, manager.ManagerOptions{})
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo
Expand Down Expand Up @@ -172,7 +172,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := manager.NewUserManager(userRepo, pwiRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)

sm := session.NewSessionManager(sRepo, skRepo)
Expand Down
4 changes: 3 additions & 1 deletion server/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
func makeTestFixtures() (*testFixtures, error) {
userRepo := user.NewUserRepoFromUsers(testUsers)
pwRepo := user.NewPasswordInfoRepoFromPasswordInfos(testPasswordInfos)
manager := manager.NewUserManager(userRepo, pwRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})

connConfigs := []connector.ConnectorConfig{
&connector.OIDCConnectorConfig{
Expand All @@ -112,6 +111,9 @@ func makeTestFixtures() (*testFixtures, error) {
ID: "local",
},
}
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)

manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})

sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
Expand Down
2 changes: 1 addition & 1 deletion test
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}

source ./build

TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api email"
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"

# user has not provided PKG override
Expand Down
6 changes: 5 additions & 1 deletion user/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/kylelemons/godebug/pretty"

"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/repo"
schema "github.com/coreos/dex/schema/workerschema"
"github.com/coreos/dex/user"
Expand Down Expand Up @@ -124,7 +125,10 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
Password: []byte("password-2"),
},
})
mgr := manager.NewUserManager(ur, pwr, repo.InMemTransactionFactory, manager.ManagerOptions{})
ccr := connector.NewConnectorConfigRepoFromConfigs([]connector.ConnectorConfig{
&connector.LocalConnectorConfig{ID: "local"},
})
mgr := manager.NewUserManager(ur, pwr, ccr, repo.InMemTransactionFactory, manager.ManagerOptions{})
mgr.Clock = clock
ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
Expand Down
21 changes: 17 additions & 4 deletions user/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/jonboulle/clockwork"

"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/user"
Expand All @@ -25,6 +26,7 @@ type UserManager struct {

userRepo user.UserRepo
pwRepo user.PasswordInfoRepo
connCfgRepo connector.ConnectorConfigRepo
begin repo.TransactionFactory
userIDGenerator user.UserIDGenerator
}
Expand All @@ -35,12 +37,13 @@ type ManagerOptions struct {
// variable policies
}

func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
func NewUserManager(userRepo user.UserRepo, pwRepo user.PasswordInfoRepo, connCfgRepo connector.ConnectorConfigRepo, txnFactory repo.TransactionFactory, options ManagerOptions) *UserManager {
return &UserManager{
Clock: clockwork.NewRealClock(),

userRepo: userRepo,
pwRepo: pwRepo,
connCfgRepo: connCfgRepo,
begin: txnFactory,
userIDGenerator: user.DefaultUserIDGenerator,
}
Expand Down Expand Up @@ -80,7 +83,7 @@ func (m *UserManager) CreateUser(usr user.User, hashedPassword user.Password, co
ConnectorID: connID,
ID: usr.ID,
}
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx)
return "", err
}
Expand Down Expand Up @@ -141,7 +144,7 @@ func (m *UserManager) RegisterWithRemoteIdentity(email string, emailVerified boo
return "", err
}

if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx)
return "", err
}
Expand Down Expand Up @@ -177,7 +180,7 @@ func (m *UserManager) RegisterWithPassword(email, plaintext, connID string) (str
ConnectorID: connID,
ID: usr.ID,
}
if err := m.userRepo.AddRemoteIdentity(tx, usr.ID, rid); err != nil {
if err := m.addRemoteIdentity(tx, usr.ID, rid); err != nil {
rollback(tx)
return "", err
}
Expand Down Expand Up @@ -338,6 +341,16 @@ func (m *UserManager) insertNewUser(tx repo.Transaction, email string, emailVeri
return usr, nil
}

func (m *UserManager) addRemoteIdentity(tx repo.Transaction, userID string, rid user.RemoteIdentity) error {
if _, err := m.connCfgRepo.GetConnectorByID(tx, rid.ConnectorID); err != nil {
return err
}
if err := m.userRepo.AddRemoteIdentity(tx, userID, rid); err != nil {
return err
}
return nil
}

func rollback(tx repo.Transaction) {
err := tx.Rollback()
if err != nil {
Expand Down
Loading

0 comments on commit f43655a

Please sign in to comment.