diff --git a/go.mod b/go.mod index 6ff4e51fc7d..ac775bad9e8 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/oleiade/reflections v1.0.0 github.com/opentracing/opentracing-go v1.0.2 github.com/ory/dockertest v3.3.2+incompatible - github.com/ory/fosite v0.28.0 + github.com/ory/fosite v0.29.0 github.com/ory/go-convenience v0.1.0 github.com/ory/graceful v0.1.0 github.com/ory/herodot v0.4.1 @@ -42,7 +42,7 @@ require ( github.com/urfave/negroni v1.0.0 github.com/ziutek/mymysql v1.5.4 // indirect go.uber.org/atomic v1.3.2 // indirect - golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b + golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613 golang.org/x/net v0.0.0-20181029044818-c44066c5c816 // indirect golang.org/x/oauth2 v0.0.0-20181003184128-c57b0facaced gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect diff --git a/go.sum b/go.sum index 0150bbc2c4b..b3dc9dc8fdd 100644 --- a/go.sum +++ b/go.sum @@ -160,6 +160,8 @@ github.com/ory/fosite v0.27.4 h1:+2Iu957COQM3vbWp5qjgq0W4icsjbtg+5y3AYJ87EjY= github.com/ory/fosite v0.27.4/go.mod h1:uttCRNB0lM7+BJFX7CC8Bqo9gAPrcpmA9Ezc80Trwuw= github.com/ory/fosite v0.28.0 h1:LxCkLXeU5PxYh9d/VbfGVn8GTKkSdOZfrHWdjmIE//c= github.com/ory/fosite v0.28.0/go.mod h1:uttCRNB0lM7+BJFX7CC8Bqo9gAPrcpmA9Ezc80Trwuw= +github.com/ory/fosite v0.29.0 h1:qFQfwy2YF1Bn5kgilT1LH3N0xOBvV865EXbj2bdxaoY= +github.com/ory/fosite v0.29.0/go.mod h1:0atSZmXO7CAcs6NPMI/Qtot8tmZYj04Nddoold4S2h0= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= github.com/ory/go-convenience v0.1.0/go.mod h1:uEY/a60PL5c12nYz4V5cHY03IBmwIAEm8TWB0yn9KNs= github.com/ory/graceful v0.1.0 h1:zilpYtcR5vp4GubV4bN2GFJewHaSkMFnnRiJxyH8FAc= @@ -282,6 +284,8 @@ golang.org/x/crypto v0.0.0-20190103213133-ff983b9c42bc h1:F5tKCVGp+MUAHhKp5MZtGq golang.org/x/crypto v0.0.0-20190103213133-ff983b9c42bc/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b h1:Elez2XeF2p9uyVj0yEUDqQ56NFcDtcBNkYP7yv8YbUE= golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613 h1:MQ/ZZiDsUapFFiMS+vzwXkCTeEKaum+Do5rINYJDmxc= +golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/net v0.0.0-20180530234432-1e491301e022/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180611182652-db08ff08e862 h1:JZi6BqOZ+iSgmLWe6llhGrNnEnK+YB/MRkStwnEfbqM= golang.org/x/net v0.0.0-20180611182652-db08ff08e862/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index 65c7e391295..7ce25179c65 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -27,6 +27,8 @@ import ( "testing" "time" + "github.com/ory/fosite/storage" + "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -118,6 +120,17 @@ func TestHelperRunner(t *testing.T, store ManagerTestSetup, k string) { t.Helper() if k != "memory" { t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteAuthorizeCodes/db=%s", k), testHelperUniqueConstraints(store, k)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionsCommitAccessToken"), testFositeSqlStoreTransactionCommitAccessToken(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionsRollbackAccessToken"), testFositeSqlStoreTransactionRollbackAccessToken(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionCommitRefreshToken"), testFositeSqlStoreTransactionCommitRefreshToken(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionRollbackRefreshToken"), testFositeSqlStoreTransactionRollbackRefreshToken(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionCommitAuthorizeCode"), testFositeSqlStoreTransactionCommitAuthorizeCode(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionRollbackAuthorizeCode"), testFositeSqlStoreTransactionRollbackAuthorizeCode(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionCommitPKCERequest"), testFositeSqlStoreTransactionCommitPKCERequest(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionRollbackPKCERequest"), testFositeSqlStoreTransactionRollbackPKCERequest(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionCommitOpenIdConnectSession"), testFositeSqlStoreTransactionCommitOpenIdConnectSession(store)) + t.Run(fmt.Sprint("case=testFositeSqlStoreTransactionRollbackOpenIdConnectSession"), testFositeSqlStoreTransactionRollbackOpenIdConnectSession(store)) + } t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteAuthorizeCodes/db=%s", k), testHelperCreateGetDeleteAuthorizeCodes(store)) t.Run(fmt.Sprintf("case=testHelperCreateGetDeleteAccessTokenSession/db=%s", k), testHelperCreateGetDeleteAccessTokenSession(store)) @@ -385,3 +398,228 @@ func testHelperFlushTokens(x ManagerTestSetup, lifespan time.Duration) func(t *t require.Error(t, err) } } + +func testFositeSqlStoreTransactionCommitAccessToken(m ManagerTestSetup) func(t *testing.T) { + return func(t *testing.T) { + { + doTestCommit(m, t, m.F.CreateAccessTokenSession, m.F.GetAccessTokenSession, m.F.RevokeAccessToken) + doTestCommit(m, t, m.F.CreateAccessTokenSession, m.F.GetAccessTokenSession, m.F.DeleteAccessTokenSession) + } + } +} + +func testFositeSqlStoreTransactionRollbackAccessToken(m ManagerTestSetup) func(t *testing.T) { + return func(t *testing.T) { + { + doTestRollback(m, t, m.F.CreateAccessTokenSession, m.F.GetAccessTokenSession, m.F.RevokeAccessToken) + doTestRollback(m, t, m.F.CreateAccessTokenSession, m.F.GetAccessTokenSession, m.F.DeleteAccessTokenSession) + } + } +} + +func testFositeSqlStoreTransactionCommitRefreshToken(m ManagerTestSetup) func(t *testing.T) { + + return func(t *testing.T) { + doTestCommit(m, t, m.F.CreateRefreshTokenSession, m.F.GetRefreshTokenSession, m.F.RevokeRefreshToken) + doTestCommit(m, t, m.F.CreateRefreshTokenSession, m.F.GetRefreshTokenSession, m.F.DeleteRefreshTokenSession) + } +} + +func testFositeSqlStoreTransactionRollbackRefreshToken(m ManagerTestSetup) func(t *testing.T) { + return func(t *testing.T) { + doTestRollback(m, t, m.F.CreateRefreshTokenSession, m.F.GetRefreshTokenSession, m.F.RevokeRefreshToken) + doTestRollback(m, t, m.F.CreateRefreshTokenSession, m.F.GetRefreshTokenSession, m.F.DeleteRefreshTokenSession) + } +} + +func testFositeSqlStoreTransactionCommitAuthorizeCode(m ManagerTestSetup) func(t *testing.T) { + + return func(t *testing.T) { + doTestCommit(m, t, m.F.CreateAuthorizeCodeSession, m.F.GetAuthorizeCodeSession, m.F.InvalidateAuthorizeCodeSession) + } +} + +func testFositeSqlStoreTransactionRollbackAuthorizeCode(m ManagerTestSetup) func(t *testing.T) { + return func(t *testing.T) { + doTestRollback(m, t, m.F.CreateAuthorizeCodeSession, m.F.GetAuthorizeCodeSession, m.F.InvalidateAuthorizeCodeSession) + } +} + +func testFositeSqlStoreTransactionCommitPKCERequest(m ManagerTestSetup) func(t *testing.T) { + + return func(t *testing.T) { + doTestCommit(m, t, m.F.CreatePKCERequestSession, m.F.GetPKCERequestSession, m.F.DeletePKCERequestSession) + } +} + +func testFositeSqlStoreTransactionRollbackPKCERequest(m ManagerTestSetup) func(t *testing.T) { + return func(t *testing.T) { + doTestRollback(m, t, m.F.CreatePKCERequestSession, m.F.GetPKCERequestSession, m.F.DeletePKCERequestSession) + } +} + +// OpenIdConnect tests can't use the helper functions, due to the signature of GetOpenIdConnectSession being +// different from the other getter methods +func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m ManagerTestSetup) func(t *testing.T) { + return func(t *testing.T) { + txnStore, ok := m.F.(storage.Transactional) + require.True(t, ok) + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + signature := uuid.New() + testRequest := createTestRequest(signature) + err = m.F.CreateOpenIDConnectSession(ctx, signature, testRequest) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + res, err := m.F.GetOpenIDConnectSession(context.Background(), signature, testRequest) + // session should have been created successfully because Commit did not return an error + require.NoError(t, err) + AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") + + // test delete within a transaction + ctx, err = txnStore.BeginTX(context.Background()) + err = m.F.DeleteOpenIDConnectSession(ctx, signature) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + _, err = m.F.GetOpenIDConnectSession(context.Background(), signature, testRequest) + // Since commit worked for delete, we should get an error here. + require.Error(t, err) + } +} + +func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m ManagerTestSetup) func(t *testing.T) { + return func(t *testing.T) { + txnStore, ok := m.F.(storage.Transactional) + require.True(t, ok) + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + + signature := uuid.New() + testRequest := createTestRequest(signature) + err = m.F.CreateOpenIDConnectSession(ctx, signature, testRequest) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + ctx = context.Background() + _, err = m.F.GetOpenIDConnectSession(ctx, signature, testRequest) + // Since we rolled back above, the session should not exist and getting it should result in an error + require.Error(t, err) + + // create a new session, delete it, then rollback the delete. We should be able to then get it. + signature2 := uuid.New() + testRequest2 := createTestRequest(signature2) + err = m.F.CreateOpenIDConnectSession(ctx, signature2, testRequest2) + require.NoError(t, err) + _, err = m.F.GetOpenIDConnectSession(ctx, signature2, testRequest2) + require.NoError(t, err) + + ctx, err = txnStore.BeginTX(context.Background()) + err = m.F.DeleteOpenIDConnectSession(ctx, signature2) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + + require.NoError(t, err) + _, err = m.F.GetOpenIDConnectSession(context.Background(), signature2, testRequest2) + require.NoError(t, err) + } +} + +func doTestCommit(m ManagerTestSetup, t *testing.T, + createFn func(context.Context, string, fosite.Requester) error, + getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), + revokeFn func(context.Context, string) error, +) { + + txnStore, ok := m.F.(storage.Transactional) + require.True(t, ok) + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + signature := uuid.New() + err = createFn(ctx, signature, createTestRequest(signature)) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + res, err := getFn(context.Background(), signature, &Session{}) + // token should have been created successfully because Commit did not return an error + require.NoError(t, err) + AssertObjectKeysEqual(t, &defaultRequest, res, "RequestedScope", "GrantedScope", "Form", "Session") + + // testrevoke within a transaction + ctx, err = txnStore.BeginTX(context.Background()) + err = revokeFn(ctx, signature) + require.NoError(t, err) + err = txnStore.Commit(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + _, err = getFn(context.Background(), signature, &Session{}) + // Since commit worked for revoke, we should get an error here. + require.Error(t, err) +} + +func doTestRollback(m ManagerTestSetup, t *testing.T, + createFn func(context.Context, string, fosite.Requester) error, + getFn func(context.Context, string, fosite.Session) (fosite.Requester, error), + revokeFn func(context.Context, string) error, +) { + txnStore, ok := m.F.(storage.Transactional) + require.True(t, ok) + + ctx := context.Background() + ctx, err := txnStore.BeginTX(ctx) + require.NoError(t, err) + signature := uuid.New() + err = createFn(ctx, signature, createTestRequest(signature)) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + require.NoError(t, err) + + // Require a new context, since the old one contains the transaction. + ctx = context.Background() + _, err = getFn(ctx, signature, &Session{}) + // Since we rolled back above, the token should not exist and getting it should result in an error + require.Error(t, err) + + // create a new token, revoke it, then rollback the revoke. We should be able to then get it successfully. + signature2 := uuid.New() + err = createFn(ctx, signature2, createTestRequest(signature2)) + require.NoError(t, err) + _, err = getFn(ctx, signature2, &Session{}) + require.NoError(t, err) + + ctx, err = txnStore.BeginTX(context.Background()) + err = revokeFn(ctx, signature2) + require.NoError(t, err) + err = txnStore.Rollback(ctx) + require.NoError(t, err) + + _, err = getFn(context.Background(), signature2, &Session{}) + require.NoError(t, err) +} + +func createTestRequest(id string) *fosite.Request { + return &fosite.Request{ + ID: id, + RequestedAt: time.Now().UTC().Round(time.Second), + Client: &client.Client{ClientID: "foobar"}, + RequestedScope: fosite.Arguments{"fa", "ba"}, + GrantedScope: fosite.Arguments{"fa", "ba"}, + RequestedAudience: fosite.Arguments{"ad1", "ad2"}, + GrantedAudience: fosite.Arguments{"ad1", "ad2"}, + Form: url.Values{"foo": []string{"bar", "baz"}}, + Session: &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}, + } +} diff --git a/oauth2/fosite_store_sql.go b/oauth2/fosite_store_sql.go index 87ad8b481a6..fabf9b95dfb 100644 --- a/oauth2/fosite_store_sql.go +++ b/oauth2/fosite_store_sql.go @@ -50,6 +50,13 @@ type FositeSQLStore struct { HashSignature bool } +type sqlxDB interface { + sqlx.ExecerContext + sqlx.Ext + NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error +} + func NewFositeSQLStore(m client.Manager, db *sqlx.DB, l logrus.FieldLogger, @@ -84,6 +91,10 @@ var Migrations = map[string]*dbal.PackrMigrationSource{ }, true), } +type transactionKey int + +const txKey transactionKey = iota + var sqlParams = []string{ "signature", "request_id", @@ -200,6 +211,7 @@ func (s *FositeSQLStore) hashSignature(signature, table string) string { } func (s *FositeSQLStore) createSession(ctx context.Context, signature string, requester fosite.Requester, table string) error { + db := s.db(ctx) signature = s.hashSignature(signature, table) data, err := sqlSchemaFromRequest(signature, requester, s.L) @@ -213,17 +225,26 @@ func (s *FositeSQLStore) createSession(ctx context.Context, signature string, re strings.Join(sqlParams, ", "), ":"+strings.Join(sqlParams, ", :"), ) - if _, err := s.DB.NamedExecContext(ctx, query, data); err != nil { + if _, err := db.NamedExecContext(ctx, query, data); err != nil { return sqlcon.HandleError(err) } return nil } +func (s *FositeSQLStore) db(ctx context.Context) sqlxDB { + if tx, ok := ctx.Value(txKey).(*sqlx.Tx); ok { + return tx + } else { + return s.DB + } +} + func (s *FositeSQLStore) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table string) (fosite.Requester, error) { + db := s.db(ctx) signature = s.hashSignature(signature, table) var d sqlData - if err := s.DB.GetContext(ctx, &d, s.DB.Rebind(fmt.Sprintf("SELECT * FROM hydra_oauth2_%s WHERE signature=?", table)), signature); err == sql.ErrNoRows { + if err := db.GetContext(ctx, &d, db.Rebind(fmt.Sprintf("SELECT * FROM hydra_oauth2_%s WHERE signature=?", table)), signature); err == sql.ErrNoRows { return nil, errors.Wrap(fosite.ErrNotFound, "") } else if err != nil { return nil, sqlcon.HandleError(err) @@ -241,9 +262,10 @@ func (s *FositeSQLStore) findSessionBySignature(ctx context.Context, signature s } func (s *FositeSQLStore) deleteSession(ctx context.Context, signature string, table string) error { + db := s.db(ctx) signature = s.hashSignature(signature, table) - if _, err := s.DB.ExecContext(ctx, s.DB.Rebind(fmt.Sprintf("DELETE FROM hydra_oauth2_%s WHERE signature=?", table)), signature); err != nil { + if _, err := db.ExecContext(ctx, s.DB.Rebind(fmt.Sprintf("DELETE FROM hydra_oauth2_%s WHERE signature=?", table)), signature); err != nil { return sqlcon.HandleError(err) } return nil @@ -279,7 +301,8 @@ func (s *FositeSQLStore) GetAuthorizeCodeSession(ctx context.Context, signature } func (s *FositeSQLStore) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) error { - if _, err := s.DB.ExecContext(ctx, s.DB.Rebind(fmt.Sprintf( + db := s.db(ctx) + if _, err := db.ExecContext(ctx, db.Rebind(fmt.Sprintf( "UPDATE hydra_oauth2_%s SET active=false WHERE signature=?", sqlTableCode, )), signature); err != nil { @@ -338,7 +361,8 @@ func (s *FositeSQLStore) RevokeAccessToken(ctx context.Context, id string) error } func (s *FositeSQLStore) revokeSession(ctx context.Context, id string, table string) error { - if _, err := s.DB.ExecContext(ctx, s.DB.Rebind(fmt.Sprintf("DELETE FROM hydra_oauth2_%s WHERE request_id=?", table)), id); err == sql.ErrNoRows { + db := s.db(ctx) + if _, err := db.ExecContext(ctx, db.Rebind(fmt.Sprintf("DELETE FROM hydra_oauth2_%s WHERE request_id=?", table)), id); err == sql.ErrNoRows { return errors.Wrap(fosite.ErrNotFound, "") } else if err != nil { return sqlcon.HandleError(err) @@ -355,3 +379,27 @@ func (s *FositeSQLStore) FlushInactiveAccessTokens(ctx context.Context, notAfter return nil } + +func (s *FositeSQLStore) BeginTX(ctx context.Context) (context.Context, error) { + if tx, err := s.DB.BeginTxx(ctx, nil); err != nil { + return ctx, err + } else { + return context.WithValue(ctx, txKey, tx), nil + } +} + +func (s *FositeSQLStore) Commit(ctx context.Context) error { + if tx, ok := ctx.Value(txKey).(*sqlx.Tx); !ok { + return errors.Wrap(fosite.ErrServerError, "commit failed: no transaction stored in context") + } else { + return tx.Commit() + } +} + +func (s *FositeSQLStore) Rollback(ctx context.Context) error { + if tx, ok := ctx.Value(txKey).(*sqlx.Tx); !ok { + return errors.Wrap(fosite.ErrServerError, "rollback failed: no transaction stored in context") + } else { + return tx.Rollback() + } +}