Database CA - trusted cluster changes (#11615)
* Add DatabaseCA to trusted cluster rotation

* Add DatabaseCA test

* Refactor NewCertAuthorityWatcher - remove code duplication.

* Refactor TestDatabaseRotateTrustedCluster.

* Add comments to a DatabaseCA test.

* Improve database rotation test.

* Clean up test
Add comments

* Rename variable

* Fix logger in DB integration test.
jakule authored Apr 4, 2022
1 parent 4a73e20 commit dbebf81
Showing 10 changed files with 261 additions and 85 deletions.
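The new test exercises a manual, phase-by-phase rotation of the database CA across trusted clusters. As a quick orientation before the diff, the sketch below is not part of this change: it only mirrors the RotateCertAuthority calls and phase order used in TestDatabaseRotateTrustedCluster further down. The rotateDatabaseCA helper name is invented for illustration, and obtaining the *auth.Server via Process.GetAuthServer() is assumed, as in the test.

package integration

import (
	"context"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/auth"
)

// rotateDatabaseCA walks the database CA through a full manual rotation, one
// phase at a time, in the same order the test below drives it.
// (Illustrative sketch; not part of this commit.)
func rotateDatabaseCA(ctx context.Context, authServer *auth.Server) error {
	phases := []string{
		types.RotationPhaseInit,
		types.RotationPhaseUpdateClients,
		types.RotationPhaseUpdateServers,
		types.RotationPhaseStandby,
	}
	for _, phase := range phases {
		// Each call advances the DatabaseCA rotation to the requested phase.
		if err := authServer.RotateCertAuthority(ctx, auth.RotateRequest{
			Type:        types.DatabaseCA,
			TargetPhase: phase,
			Mode:        types.RotationModeManual,
		}); err != nil {
			return err
		}
	}
	return nil
}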
187 changes: 183 additions & 4 deletions integration/db_integration_test.go
@@ -21,9 +21,11 @@ import (
"fmt"
"net"
"net/http"
"strings"
"testing"
"time"

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
@@ -41,12 +43,12 @@ import (
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"

"github.com/google/uuid"
"github.com/jackc/pgconn"
"github.com/jonboulle/clockwork"
"github.com/siddontang/go-mysql/client"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
)
@@ -118,6 +120,182 @@ func TestDatabaseAccessPostgresLeafCluster(t *testing.T) {
require.NoError(t, err)
}

func TestDatabaseRotateTrustedCluster(t *testing.T) {
pack := setupDatabaseTest(t,
// set tighter rotation intervals
withLeafConfig(func(config *service.Config) {
config.PollingPeriod = 5 * time.Second
config.RotationConnectionInterval = 2 * time.Second
}),
withRootConfig(func(config *service.Config) {
config.PollingPeriod = 5 * time.Second
config.RotationConnectionInterval = 2 * time.Second
}))
pack.waitForLeaf(t)

var (
ctx = context.Background()
rootCluster = pack.root.cluster
authServer = rootCluster.Process.GetAuthServer()
clusterRootName = rootCluster.Secrets.SiteName
clusterLeafName = pack.leaf.cluster.Secrets.SiteName
)

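// The phaseWatcher (defined below) follows the leaf cluster's view of the root
// cluster's database CA, so the test can block until each rotation phase has
// propagated through the trusted-cluster connection.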
pw := phaseWatcher{
clusterRootName: clusterRootName,
pollingPeriod: rootCluster.Process.Config.PollingPeriod,
clock: pack.clock,
siteAPI: rootCluster.GetSiteAPI(clusterLeafName),
certType: types.DatabaseCA,
}

currentDbCA, err := pack.root.dbAuthClient.GetCertAuthority(ctx, types.CertAuthID{
Type: types.DatabaseCA,
DomainName: clusterRootName,
}, false)
require.NoError(t, err)

rotationPhases := []string{types.RotationPhaseInit, types.RotationPhaseUpdateClients,
types.RotationPhaseUpdateServers, types.RotationPhaseStandby}

waitForEvent := func(process *service.TeleportProcess, event string) {
eventC := make(chan service.Event, 1)
process.WaitForEvent(context.TODO(), event, eventC)
select {
case <-eventC:

case <-time.After(20 * time.Second):
t.Fatalf("timeout waiting for service to broadcast event %s", event)
}
}

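// Drive the database CA through every rotation phase; for each phase, wait
// until the leaf cluster reports the same phase, then (except for Init) wait
// for both clusters to finish reloading before moving on.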
for _, phase := range rotationPhases {
errChan := make(chan error, 1)

go func() {
errChan <- pw.waitForPhase(phase, func() error {
return authServer.RotateCertAuthority(ctx, auth.RotateRequest{
Type: types.DatabaseCA,
TargetPhase: phase,
Mode: types.RotationModeManual,
})
})
}()

err = <-errChan

if err != nil && strings.Contains(err.Error(), "context deadline exceeded") {
// TODO(jakule): Workaround for CertAuthorityWatcher failing to get the correct rotation status.
// Query auth server directly to see if the incorrect rotation status is a rotation or watcher problem.
dbCA, err := pack.leaf.cluster.Process.GetAuthServer().GetCertAuthority(ctx, types.CertAuthID{
Type: types.DatabaseCA,
DomainName: clusterRootName,
}, false)
require.NoError(t, err)
require.Equal(t, dbCA.GetRotation().Phase, phase)
} else {
require.NoError(t, err)
}

// Reload doesn't happen on Init
if phase == types.RotationPhaseInit {
continue
}

waitForEvent(pack.root.cluster.Process, service.TeleportReloadEvent)
waitForEvent(pack.leaf.cluster.Process, service.TeleportReadyEvent)

pack.waitForLeaf(t)
}

rotatedDbCA, err := authServer.GetCertAuthority(ctx, types.CertAuthID{
Type: types.DatabaseCA,
DomainName: clusterRootName,
}, false)
require.NoError(t, err)

// Sanity check. Check if the CA was rotated.
require.NotEqual(t, currentDbCA.GetActiveKeys(), rotatedDbCA.GetActiveKeys())

// Connect to the database service in leaf cluster via root cluster.
dbClient, err := postgres.MakeTestClient(context.Background(), common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
AuthServer: pack.root.cluster.Process.GetAuthServer(),
Address: net.JoinHostPort(Loopback, pack.root.cluster.GetPortWeb()), // Connecting via root cluster.
Cluster: pack.leaf.cluster.Secrets.SiteName,
Username: pack.root.user.GetName(),
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: pack.leaf.postgresService.Name,
Protocol: pack.leaf.postgresService.Protocol,
Username: "postgres",
Database: "test",
},
})
require.NoError(t, err)

// Execute a query.
result, err := dbClient.Exec(context.Background(), "select 1").ReadAll()
require.NoError(t, err)
require.Equal(t, []*pgconn.Result{postgres.TestQueryResponse}, result)
require.Equal(t, uint32(1), pack.leaf.postgres.QueryCount())
require.Equal(t, uint32(0), pack.root.postgres.QueryCount())

// Disconnect.
err = dbClient.Close(context.Background())
require.NoError(t, err)
}

// phaseWatcher holds all arguments required by rotation watcher.
type phaseWatcher struct {
clusterRootName string
pollingPeriod time.Duration
clock clockwork.Clock
siteAPI types.Events
certType types.CertAuthType
}

// waitForPhase waits until the watched cluster detects the given rotation phase of the root cluster's CA. fn is a
// rotation function that is called after the watcher is created.
func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
ctx, cancel := context.WithTimeout(context.Background(), p.pollingPeriod*10)
defer cancel()

watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: p.clock,
Client: p.siteAPI,
},
WatchCertTypes: []types.CertAuthType{p.certType},
})
if err != nil {
return err
}
defer watcher.Close()

if err := fn(); err != nil {
return trace.Wrap(err)
}

var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q certType: %v err: %v", phase, lastPhase, p.certType, ctx.Err())
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == p.clusterRootName &&
ca.GetType() == p.certType &&
ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
}
}
}
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)
}

// TestDatabaseAccessMySQLRootCluster tests a scenario where a user connects
// to a MySQL database running in a root cluster.
func TestDatabaseAccessMySQLRootCluster(t *testing.T) {
@@ -1043,15 +1221,16 @@ func (p *databasePack) waitForLeaf(t *testing.T) {
case <-time.Tick(500 * time.Millisecond):
servers, err := accessPoint.GetDatabaseServers(ctx, apidefaults.Namespace)
if err != nil {
logrus.WithError(err).Debugf("Leaf cluster access point is unavailable.")
// Use the root cluster's logger as we need a configured logger instance and the root cluster has one.
p.root.cluster.log.WithError(err).Debugf("Leaf cluster access point is unavailable.")
continue
}
if !containsDB(servers, p.leaf.mysqlService.Name) {
logrus.WithError(err).Debugf("Leaf db service %q is unavailable.", p.leaf.mysqlService.Name)
p.root.cluster.log.WithError(err).Debugf("Leaf db service %q is unavailable.", p.leaf.mysqlService.Name)
continue
}
if !containsDB(servers, p.leaf.postgresService.Name) {
logrus.WithError(err).Debugf("Leaf db service %q is unavailable.", p.leaf.postgresService.Name)
p.root.cluster.log.WithError(err).Debugf("Leaf db service %q is unavailable.", p.leaf.postgresService.Name)
continue
}
return
2 changes: 1 addition & 1 deletion integration/helpers.go
@@ -384,7 +384,7 @@ func (s *InstanceSecrets) GetIdentity() *auth.Identity {
return i
}

// GetSiteAPI() is a helper which returns an API endpoint to a site with
// GetSiteAPI is a helper which returns an API endpoint to a site with
// a given name. i endpoint implements HTTP-over-SSH access to the
// site's auth server.
func (i *TeleInstance) GetSiteAPI(siteName string) auth.ClientI {
36 changes: 18 additions & 18 deletions integration/integration_test.go
@@ -2164,7 +2164,7 @@ func trustedClusters(t *testing.T, suite *integrationTestSuite, test trustedClus
{Remote: mainOps, Local: []string{auxDevs}},
})

// modify trusted cluster resource name so it would not
// modify trusted cluster resource name, so it would not
// match the cluster name to check that it does not matter
trustedCluster.SetName(main.Secrets.SiteName + "-cluster")

@@ -2263,7 +2263,7 @@ }
}
require.Error(t, err, "expected tunnel to close and SSH client to start failing")

// remove trusted cluster from aux cluster side, and recrete right after
// remove trusted cluster from aux cluster side, and recreate right after
// this should re-establish connection
err = aux.Process.GetAuthServer().DeleteTrustedCluster(ctx, trustedCluster.GetName())
require.NoError(t, err)
@@ -2305,7 +2305,7 @@ func waitForClusters(tun reversetunnel.Server, expected int) func() bool {
return false
}

// Check the expected number of clusters are connected and they have all
// Check the expected number of clusters are connected, and they have all
// connected with the past 10 seconds.
if len(clusters) >= expected {
for _, cluster := range clusters {
@@ -2374,7 +2374,7 @@ func testTrustedTunnelNode(t *testing.T, suite *integrationTestSuite) {
{Remote: mainDevs, Local: []string{auxDevs}},
})

// modify trusted cluster resource name so it would not
// modify trusted cluster resource name, so it would not
// match the cluster name to check that it does not matter
trustedCluster.SetName(main.Secrets.SiteName + "-cluster")

@@ -3732,7 +3732,7 @@ func testRotateSuccess(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// wait until service reload
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)
defer svc.Shutdown(context.TODO())

@@ -3763,7 +3763,7 @@ func testRotateSuccess(t *testing.T, suite *integrationTestSuite) {
t.Logf("Cert authority: %v", auth.CertAuthorityInfo(hostCA))

// wait until service reloaded
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)
defer svc.Shutdown(context.TODO())

@@ -3792,7 +3792,7 @@ func testRotateSuccess(t *testing.T, suite *integrationTestSuite) {
t.Logf("Cert authority: %v", auth.CertAuthorityInfo(hostCA))

// wait until service reloaded
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)
defer svc.Shutdown(context.TODO())

@@ -3878,7 +3878,7 @@ func testRotateRollback(t *testing.T, s *integrationTestSuite) {
require.NoError(t, err)

// wait until service reload
svc, err = s.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

cfg := ClientConfig{
@@ -3904,7 +3904,7 @@ func testRotateRollback(t *testing.T, s *integrationTestSuite) {
require.NoError(t, err)

// wait until service reloaded
svc, err = s.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

t.Logf("Service reloaded. Setting rotation state to %q.", types.RotationPhaseRollback)
@@ -3917,7 +3917,7 @@ func testRotateRollback(t *testing.T, s *integrationTestSuite) {
require.NoError(t, err)

// wait until service reloaded
svc, err = s.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

// old client works
@@ -4064,7 +4064,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
WatchHostCA: true,
WatchCertTypes: []types.CertAuthType{types.HostCA},
})
if err != nil {
return err
@@ -4101,7 +4101,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// wait until service reloaded
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateClients)
@@ -4121,7 +4121,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// wait until service reloaded
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateServers)
@@ -4149,7 +4149,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {

// wait until service reloaded
t.Log("Waiting for service reload.")
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)
t.Log("Service reload completed, waiting for phase.")

@@ -4257,7 +4257,7 @@ func testRotateChangeSigningAlg(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// wait until service reload
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

t.Logf("Rotation phase: %q.", types.RotationPhaseUpdateServers)
@@ -4268,7 +4268,7 @@ func testRotateChangeSigningAlg(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// wait until service reloaded
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

t.Logf("rotation phase: %q", types.RotationPhaseStandby)
@@ -4279,7 +4279,7 @@ func testRotateChangeSigningAlg(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// wait until service reloaded
svc, err = suite.waitForReload(serviceC, svc)
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)

return svc
@@ -4367,7 +4367,7 @@ func waitForProcessStart(serviceC chan *service.TeleportProcess) (*service.Telep
// 2. old service, if present to shut down
//
// this helper function allows to serialize tests for reloads.
func (s *integrationTestSuite) waitForReload(serviceC chan *service.TeleportProcess, old *service.TeleportProcess) (*service.TeleportProcess, error) {
func waitForReload(serviceC chan *service.TeleportProcess, old *service.TeleportProcess) (*service.TeleportProcess, error) {
var svc *service.TeleportProcess
select {
case svc = <-serviceC: