From a77881a0f23f72e93ea679e9c3f9539209f3974c Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 10 May 2022 18:55:15 +0000 Subject: [PATCH 1/2] ccl/sqlproxyccl: invoke rebalancing logic during RUNNING pod events This commit invokes the rebalancing logic during RUNNING pod events as part of the pod watcher. Since the rebalancing logic depends on the tenant directory, the pod watcher will now only emit events once the directory has been updated. This is done for better responsiveness, i.e. the moment a new SQL pod gets added, we would like to rebalance all connections to the tenant. Note that the Watch endpoint on the tenant directory server currently emits events in multiple cases: changes to load, and changes to pod (added/modified/ deleted). The plan is to update the tenant directory server to only emit events for pod updates. The next commit will rate limit the number of times the rebalancing logic for a given tenant can be called. At the same time, we introduce a new test static directory server which does not automatically spin up tenants for us (i.e. SQL pods for tenants can now be managed manually, giving more control to tests). Release note: None --- pkg/ccl/sqlproxyccl/balancer/balancer.go | 103 +++-- pkg/ccl/sqlproxyccl/proxy_handler.go | 63 ++- pkg/ccl/sqlproxyccl/proxy_handler_test.go | 174 +++++-- pkg/ccl/sqlproxyccl/tenant/BUILD.bazel | 4 + pkg/ccl/sqlproxyccl/tenant/directory_cache.go | 11 +- .../tenant/directory_cache_test.go | 210 +++++---- pkg/ccl/sqlproxyccl/tenant/entry.go | 1 + pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel | 1 + .../tenantdirsvr/test_directory_svr.go | 2 +- .../tenantdirsvr/test_simple_directory_svr.go | 2 +- .../tenantdirsvr/test_static_directory_svr.go | 432 ++++++++++++++++++ .../sqlccl/show_transfer_state_test.go | 4 +- pkg/testutils/lint/lint_test.go | 1 + 13 files changed, 818 insertions(+), 190 deletions(-) create mode 100644 pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go diff --git a/pkg/ccl/sqlproxyccl/balancer/balancer.go b/pkg/ccl/sqlproxyccl/balancer/balancer.go index b15cfb3d8d64..a67b4ffea1b8 100644 --- a/pkg/ccl/sqlproxyccl/balancer/balancer.go +++ b/pkg/ccl/sqlproxyccl/balancer/balancer.go @@ -50,15 +50,15 @@ const ( // NOTE: This must be between 0 and 1 inclusive. rebalancePercentDeviation = 0.15 - // rebalanceRate defines the rate of rebalancing assignments across SQL - // pods. This rate applies to both RUNNING and DRAINING pods. For example, - // consider the case where the rate is 0.50; if we have decided that we need - // to move 15 assignments away from a particular pod, only 7 pods will be - // moved at a time. + // defaultRebalanceRate defines the rate of rebalancing assignments across + // SQL pods. This rate applies to both RUNNING and DRAINING pods. For + // example, consider the case where the rate is 0.50; if we have decided + // that we need to move 15 assignments away from a particular pod, only 7 + // pods will be moved at a time. // // NOTE: This must be between 0 and 1 inclusive. 0 means no rebalancing // will occur. - rebalanceRate = 0.50 + defaultRebalanceRate = 0.50 // defaultMaxConcurrentRebalances represents the maximum number of // concurrent rebalance requests that are being processed. 
This effectively @@ -78,6 +78,7 @@ type balancerOptions struct { maxConcurrentRebalances int noRebalanceLoop bool timeSource timeutil.TimeSource + rebalanceRate float32 } // Option defines an option that can be passed to NewBalancer in order to @@ -109,6 +110,14 @@ func TimeSource(ts timeutil.TimeSource) Option { } } +// RebalanceRate defines the rate of rebalancing across pods. Set to -1 to +// disable rebalancing (i.e. connection transfers). +func RebalanceRate(rate float32) Option { + return func(opts *balancerOptions) { + opts.rebalanceRate = rate + } +} + // Balancer handles load balancing of SQL connections within the proxy. // All methods on the Balancer instance are thread-safe. type Balancer struct { @@ -140,6 +149,9 @@ type Balancer struct { // timeutil.DefaultTimeSource. Override with the TimeSource() option when // calling NewBalancer. timeSource timeutil.TimeSource + + // rebalanceRate represents the rate of rebalancing connections. + rebalanceRate float32 } // NewBalancer constructs a new Balancer instance that is responsible for @@ -162,6 +174,12 @@ func NewBalancer( if options.timeSource == nil { options.timeSource = timeutil.DefaultTimeSource{} } + if options.rebalanceRate == 0 { + options.rebalanceRate = defaultRebalanceRate + } + if options.rebalanceRate == -1 { + options.rebalanceRate = 0 + } // Ensure that ctx gets cancelled on stopper's quiescing. ctx, _ = stopper.WithCancelOnQuiesce(ctx) @@ -178,6 +196,7 @@ func NewBalancer( queue: q, processSem: semaphore.New(options.maxConcurrentRebalances), timeSource: options.timeSource, + rebalanceRate: options.rebalanceRate, } b.connTracker, err = NewConnTracker(ctx, b.stopper, b.timeSource) if err != nil { @@ -300,57 +319,57 @@ func (b *Balancer) rebalanceLoop(ctx context.Context) { } // rebalance attempts to rebalance connections for all tenants within the proxy. -// -// TODO(jaylim-crl): Update this to support rebalancing a single tenant. That -// way, the pod watcher could call this to rebalance a single tenant. We may -// also want to rate limit the number of rebalances per tenant for requests -// coming from the pod watcher. func (b *Balancer) rebalance(ctx context.Context) { // getTenantIDs ensures that tenants will have at least one connection. tenantIDs := b.connTracker.getTenantIDs() - for _, tenantID := range tenantIDs { - tenantPods, err := b.directoryCache.TryLookupTenantPods(ctx, tenantID) - if err != nil { - // This case shouldn't really occur unless there's a bug in the - // directory server (e.g. deleted pod events, but the pod is still - // alive). - log.Errorf(ctx, "could not lookup pods for tenant %s: %v", tenantID, err.Error()) - continue - } + b.RebalanceTenant(ctx, tenantID) + } +} - // Construct a map so we could easily retrieve the pod by address. - podMap := make(map[string]*tenant.Pod) - var hasRunningPod bool - for _, pod := range tenantPods { - podMap[pod.Addr] = pod +// RebalanceTenant rebalances connections for the given tenant. If no RUNNING +// pod exists for the given tenant, this is a no-op. +// +// TODO(jaylim-crl): Rate limit the number of rebalances per tenant for requests +// coming from the pod watcher. 
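+//
+// RebalanceTenant is exported so that the proxy's pod watcher can trigger a
+// rebalance for a tenant as soon as a RUNNING pod event is observed.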
+func (b *Balancer) RebalanceTenant(ctx context.Context, tenantID roachpb.TenantID) { + tenantPods, err := b.directoryCache.TryLookupTenantPods(ctx, tenantID) + if err != nil { + log.Errorf(ctx, "could not rebalance tenant %s: %v", tenantID, err.Error()) + return + } - if pod.State == tenant.RUNNING { - hasRunningPod = true - } - } + // Construct a map so we could easily retrieve the pod by address. + podMap := make(map[string]*tenant.Pod) + var hasRunningPod bool + for _, pod := range tenantPods { + podMap[pod.Addr] = pod - // Only attempt to rebalance if we have a RUNNING pod. In theory, this - // case would happen if we're scaling down from 1 to 0, which in that - // case, we can't transfer connections anywhere. Practically, we will - // never scale a tenant from 1 to 0 if there are still active - // connections, so this case should not occur. - if !hasRunningPod { - continue + if pod.State == tenant.RUNNING { + hasRunningPod = true } + } - activeList, idleList := b.connTracker.listAssignments(tenantID) - b.rebalancePartition(podMap, activeList) - b.rebalancePartition(podMap, idleList) + // Only attempt to rebalance if we have a RUNNING pod. In theory, this + // case would happen if we're scaling down from 1 to 0, which in that + // case, we can't transfer connections anywhere. Practically, we will + // never scale a tenant from 1 to 0 if there are still active + // connections, so this case should not occur. + if !hasRunningPod { + return } + + activeList, idleList := b.connTracker.listAssignments(tenantID) + b.rebalancePartition(podMap, activeList) + b.rebalancePartition(podMap, idleList) } // rebalancePartition rebalances the given assignments partition. func (b *Balancer) rebalancePartition( pods map[string]*tenant.Pod, assignments []*ServerAssignment, ) { - // Nothing to do here. - if len(pods) == 0 || len(assignments) == 0 { + // Nothing to do here if there are no assignments, or only one pod. + if len(pods) <= 1 || len(assignments) == 0 { return } @@ -371,7 +390,7 @@ func (b *Balancer) rebalancePartition( // // NOTE: Elements in the list may be shuffled around once this method returns. func (b *Balancer) enqueueRebalanceRequests(list []*ServerAssignment) { - toMoveCount := int(math.Ceil(float64(len(list)) * float64(rebalanceRate))) + toMoveCount := int(math.Ceil(float64(len(list)) * float64(b.rebalanceRate))) partition, _ := partitionNRandom(list, toMoveCount) for _, a := range partition { b.queue.enqueue(&rebalanceRequest{ diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 9c85db56e4ce..c438da17b3bd 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -105,9 +105,13 @@ type ProxyOptions struct { // proxy. dirOpts []tenant.DirOption - // directoryServer represents the in-memory directory server that is - // created whenever a routing rule is used. + // directoryServer represents the directory server that will be used + // by the proxy handler. If unset, initializing the proxy handler will + // create one, and populate this value. directoryServer tenant.DirectoryServer + + // balancerOpts is used to customize the balancer created by the proxy. + balancerOpts []balancer.Option } } @@ -184,8 +188,31 @@ func newProxyHandler( throttler.WithBaseDelay(handler.ThrottleBaseDelay), ) + // TODO(jaylim-crl): Clean up how we start different types of directory + // servers. We could have two options: remote or local. 
Local servers that + // are using the in-memory implementation should only be used for testing + // only. For production use-cases, we will need to update that to listen + // on an actual network address, but there are no plans to support that at + // the moment. var conn *grpc.ClientConn - if handler.DirectoryAddr != "" { + if handler.testingKnobs.directoryServer != nil { + // TODO(jaylim-crl): For now, only support the static version. We should + // make this part of a LocalDirectoryServer interface for us to grab the + // in-memory listener. + directoryServer, ok := handler.testingKnobs.directoryServer.(*tenantdirsvr.TestStaticDirectoryServer) + if !ok { + return nil, errors.New("unsupported test directory server") + } + conn, err = grpc.DialContext( + ctx, + "", + grpc.WithContextDialer(directoryServer.DialerFunc), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, err + } + } else if handler.DirectoryAddr != "" { conn, err = grpc.Dial( handler.DirectoryAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -235,7 +262,11 @@ func newProxyHandler( balancerMetrics := balancer.NewMetrics() registry.AddMetricStruct(balancerMetrics) - handler.balancer, err = balancer.NewBalancer(ctx, stopper, balancerMetrics, handler.directoryCache) + var balancerOpts []balancer.Option + if handler.testingKnobs.balancerOpts != nil { + balancerOpts = append(balancerOpts, handler.testingKnobs.balancerOpts...) + } + handler.balancer, err = balancer.NewBalancer(ctx, stopper, balancerMetrics, handler.directoryCache, balancerOpts...) if err != nil { return nil, err } @@ -406,22 +437,26 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn } // startPodWatcher runs on a background goroutine and listens to pod change -// notifications. When a pod enters the DRAINING state, connections to that pod -// are subject to an idle timeout that closes them after a short period of -// inactivity. If a pod transitions back to the RUNNING state or to the DELETING -// state, then the idle timeout needs to be cleared. -// -// TODO(jaylim-crl): Update comment above. +// notifications. When a pod transitions into the DRAINING state, a rebalance +// operation will be attempted for that particular pod's tenant. func (handler *proxyHandler) startPodWatcher(ctx context.Context, podWatcher chan *tenant.Pod) { for { select { case <-ctx.Done(): return - case <-podWatcher: - // TODO(jaylim-crl): Invoke rebalance logic here whenever we see - // a new SQL pod. + case pod := <-podWatcher: + // For better responsiveness, we only care about RUNNING pods. + // DRAINING pods can wait until the next rebalance tick. // - // Do nothing for now. + // Note that there's no easy way to tell whether a SQL pod is new + // since this may race with fetchPodsLocked in tenantEntry, so we + // will just attempt to rebalance whenever we see a RUNNING pod + // event. In most cases, this should only occur when a new SQL pod + // gets added (i.e. stamped), or a DRAINING pod transitions to a + // RUNNING pod. 
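+			//
+			// Pod events that do not carry a valid (non-zero) tenant ID are
+			// ignored.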
+ if pod.State == tenant.RUNNING && pod.TenantID != 0 { + handler.balancer.RebalanceTenant(ctx, roachpb.MakeTenantID(pod.TenantID)) + } } } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 26c0f8967c6a..192d5502ee1a 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -16,6 +16,7 @@ import ( "io/ioutil" "net" "os" + "sort" "strings" "sync" "sync/atomic" @@ -25,6 +26,7 @@ import ( "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/balancer" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenantdirsvr" @@ -763,6 +765,121 @@ func TestDirectoryConnect(t *testing.T) { }) } +func TestPodWatcher(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + defer log.Scope(t).Close(t) + + // Start KV server, and enable session migration. + params, _ := tests.CreateTestServerParams() + s, mainDB, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(ctx) + _, err := mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") + require.NoError(t, err) + + // Start four SQL pods for the test tenant. + var addresses []string + tenantID := serverutils.TestTenantID() + const podCount = 4 + for i := 0; i < podCount; i++ { + params := tests.CreateTestTenantParams(tenantID) + // The first SQL pod will create the tenant keyspace in the host. + if i != 0 { + params.Existing = true + } + tenant, tenantDB := serverutils.StartTenant(t, s, params) + tenant.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant.Stopper().Stop(ctx) + + // Create a test user. We only need to do it once. + if i == 0 { + _, err = tenantDB.Exec("CREATE USER testuser WITH PASSWORD 'hunter2'") + require.NoError(t, err) + _, err = tenantDB.Exec("GRANT admin TO testuser") + require.NoError(t, err) + } + tenantDB.Close() + + addresses = append(addresses, tenant.SQLAddr()) + } + + // Register only 3 SQL pods in the directory server. We will add the 4th + // once the watcher has been established. + tds := tenantdirsvr.NewTestStaticDirectoryServer(s.Stopper(), nil /* timeSource */) + tds.CreateTenant(tenantID, "tenant-cluster") + for i := 0; i < 3; i++ { + tds.AddPod(tenantID, &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: addresses[i], + State: tenant.RUNNING, + StateTimestamp: timeutil.Now(), + }) + } + require.NoError(t, tds.Start(ctx)) + + opts := &ProxyOptions{SkipVerify: true} + opts.testingKnobs.directoryServer = tds + opts.testingKnobs.balancerOpts = []balancer.Option{ + balancer.NoRebalanceLoop(), + balancer.RebalanceRate(1.0), + } + proxy, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + + // Open 12 connections to it. The balancer should distribute the connections + // evenly across 3 SQL pods (i.e. 4 connections each). 
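+	//
+	// dist tracks the number of connections per SQL pod address.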
+ dist := map[string]int{} + var conns []*gosql.DB + for i := 0; i < 12; i++ { + db, err := gosql.Open("postgres", connectionString) + db.SetMaxOpenConns(1) + defer db.Close() + require.NoError(t, err) + addr := queryAddr(ctx, t, db) + dist[addr]++ + conns = append(conns, db) + } + + // Validate that connections are balanced evenly (i.e. 12/3 = 4). + for _, d := range dist { + require.Equal(t, 4, d) + } + + // Register the 4th pod. This should emit an event to the pod watcher, which + // triggers rebalancing. Based on the balancer's algorithm, balanced is + // defined as [2, 4] connections. As a result, 2 connections will be moved + // to the new pod. Note that for testing, we set customRebalanceRate to 1.0. + tds.AddPod(tenantID, &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: addresses[3], + State: tenant.RUNNING, + StateTimestamp: timeutil.Now(), + }) + + // Wait until two connections have been migrated. + testutils.SucceedsSoon(t, func() error { + if proxy.metrics.ConnMigrationSuccessCount.Count() >= 2 { + return nil + } + return errors.New("waiting for connection migration") + }) + + // Reset distribution and count again. + dist = map[string]int{} + for _, c := range conns { + addr := queryAddr(ctx, t, c) + dist[addr]++ + } + + // Validate distribution. + var counts []int + for _, d := range dist { + counts = append(counts, d) + } + sort.Ints(counts) + require.Equal(t, []int{2, 3, 3, 4}, counts) +} + func TestConnectionMigration(t *testing.T) { defer leaktest.AfterTest(t)() ctx := context.Background() @@ -773,9 +890,7 @@ func TestConnectionMigration(t *testing.T) { defer s.Stopper().Stop(ctx) tenantID := serverutils.TestTenantID() - // TODO(rafi): use ALTER TENANT ALL when available. - _, err := mainDB.Exec(`INSERT INTO system.tenant_settings (tenant_id, name, value, value_type) VALUES - (0, 'server.user_login.session_revival_token.enabled', 'true', 'b')`) + _, err := mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") require.NoError(t, err) // Start first SQL pod. @@ -806,23 +921,6 @@ func TestConnectionMigration(t *testing.T) { connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) - type queryer interface { - QueryRowContext(context.Context, string, ...interface{}) *gosql.Row - } - // queryAddr queries the SQL node that `db` is connected to for its address. - queryAddr := func(t *testing.T, ctx context.Context, db queryer) string { - t.Helper() - var host, port string - require.NoError(t, db.QueryRowContext(ctx, ` - SELECT - a.value AS "host", b.value AS "port" - FROM crdb_internal.node_runtime_info a, crdb_internal.node_runtime_info b - WHERE a.component = 'DB' AND a.field = 'Host' - AND b.component = 'DB' AND b.field = 'Port' - `).Scan(&host, &port)) - return fmt.Sprintf("%s:%s", host, port) - } - // validateMiscMetrics ensures that our invariant of // attempts = success + error_recoverable + error_fatal is valid, and all // other transfer related metrics were incremented as well. @@ -887,7 +985,7 @@ func TestConnectionMigration(t *testing.T) { } t.Run("normal_transfer", func(t *testing.T) { - require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + require.Equal(t, tenant1.SQLAddr(), queryAddr(tCtx, t, db)) _, err = db.Exec("SET application_name = 'foo'") require.NoError(t, err) @@ -895,7 +993,7 @@ func TestConnectionMigration(t *testing.T) { // Show that we get alternating SQL pods when we transfer. 
require.NoError(t, f.TransferConnection()) require.Equal(t, int64(1), f.metrics.ConnMigrationSuccessCount.Count()) - require.Equal(t, tenant2.SQLAddr(), queryAddr(t, tCtx, db)) + require.Equal(t, tenant2.SQLAddr(), queryAddr(tCtx, t, db)) var name string require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) @@ -906,7 +1004,7 @@ func TestConnectionMigration(t *testing.T) { require.NoError(t, f.TransferConnection()) require.Equal(t, int64(2), f.metrics.ConnMigrationSuccessCount.Count()) - require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + require.Equal(t, tenant1.SQLAddr(), queryAddr(tCtx, t, db)) require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) require.Equal(t, "bar", name) @@ -929,7 +1027,7 @@ func TestConnectionMigration(t *testing.T) { // This loop will run approximately 5 seconds. var tenant1Addr, tenant2Addr int for i := 0; i < 100; i++ { - addr := queryAddr(t, tCtx, db) + addr := queryAddr(tCtx, t, db) if addr == tenant1.SQLAddr() { tenant1Addr++ } else { @@ -962,7 +1060,7 @@ func TestConnectionMigration(t *testing.T) { // transfers should not close the connection. t.Run("failed_transfers_with_tx", func(t *testing.T) { initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() - initAddr := queryAddr(t, tCtx, db) + initAddr := queryAddr(tCtx, t, db) err = crdb.ExecuteTx(tCtx, db, nil /* txopts */, func(tx *gosql.Tx) error { // Run multiple times to ensure that connection isn't closed. @@ -974,7 +1072,7 @@ func TestConnectionMigration(t *testing.T) { if !assert.Regexp(t, "cannot serialize", err.Error()) { return errors.Wrap(err, "non-serialization error") } - addr := queryAddr(t, tCtx, tx) + addr := queryAddr(tCtx, t, tx) if initAddr != addr { return errors.Newf( "address does not match, expected %s, found %s", @@ -996,7 +1094,7 @@ func TestConnectionMigration(t *testing.T) { // Once the transaction is closed, transfers should work. require.NoError(t, f.TransferConnection()) - require.NotEqual(t, initAddr, queryAddr(t, tCtx, db)) + require.NotEqual(t, initAddr, queryAddr(tCtx, t, db)) require.Nil(t, f.ctx.Err()) require.Equal(t, initSuccessCount+1, f.metrics.ConnMigrationSuccessCount.Count()) require.Equal(t, int64(5), f.metrics.ConnMigrationErrorRecoverableCount.Count()) @@ -1011,7 +1109,7 @@ func TestConnectionMigration(t *testing.T) { t.Run("failed_transfers_with_dial_issues", func(t *testing.T) { initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() initErrorRecoverableCount := f.metrics.ConnMigrationErrorRecoverableCount.Count() - initAddr := queryAddr(t, tCtx, db) + initAddr := queryAddr(tCtx, t, db) // Set the delay longer than the timeout. lookupAddrDelayDuration = 10 * time.Second @@ -1020,7 +1118,7 @@ func TestConnectionMigration(t *testing.T) { err := f.TransferConnection() require.Error(t, err) require.Regexp(t, "injected delays", err.Error()) - require.Equal(t, initAddr, queryAddr(t, tCtx, db)) + require.Equal(t, initAddr, queryAddr(tCtx, t, db)) require.Nil(t, f.ctx.Err()) require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count()) @@ -1618,3 +1716,21 @@ func mustGetTestSimpleDirectoryServer( require.True(t, ok) return svr } + +type queryer interface { + QueryRowContext(context.Context, string, ...interface{}) *gosql.Row +} + +// queryAddr queries the SQL node that `db` is connected to for its address. 
+func queryAddr(ctx context.Context, t *testing.T, db queryer) string { + t.Helper() + var host, port string + require.NoError(t, db.QueryRowContext(ctx, ` + SELECT + a.value AS "host", b.value AS "port" + FROM crdb_internal.node_runtime_info a, crdb_internal.node_runtime_info b + WHERE a.component = 'DB' AND a.field = 'Host' + AND b.component = 'DB' AND b.field = 'Port' + `).Scan(&host, &port)) + return fmt.Sprintf("%s:%s", host, port) +} diff --git a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel index 7b570b7d2d25..0f3da2561c31 100644 --- a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel @@ -63,6 +63,7 @@ go_test( "//pkg/security/securitytest", "//pkg/server", "//pkg/sql", + "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/skip", "//pkg/testutils/testcluster", @@ -70,9 +71,12 @@ go_test( "//pkg/util/log", "//pkg/util/randutil", "//pkg/util/stop", + "//pkg/util/timeutil", + "@com_github_cockroachdb_errors//:errors", "@com_github_stretchr_testify//require", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//credentials/insecure", "@org_golang_google_grpc//status", ], ) diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_cache.go b/pkg/ccl/sqlproxyccl/tenant/directory_cache.go index 175a69d1c868..095b66c529b4 100644 --- a/pkg/ccl/sqlproxyccl/tenant/directory_cache.go +++ b/pkg/ccl/sqlproxyccl/tenant/directory_cache.go @@ -416,7 +416,12 @@ func (d *directoryCache) watchPods(ctx context.Context, stopper *stop.Stopper) e continue } - // If caller is watching pods, send to its channel now. + // Update the directory entry for the tenant with the latest + // information about this pod. + d.updateTenantEntry(ctx, resp.Pod) + + // If caller is watching pods, send to its channel now. Only do this + // after updating the tenant entry in the directory. if d.options.podWatcher != nil { select { case d.options.podWatcher <- resp.Pod: @@ -424,10 +429,6 @@ func (d *directoryCache) watchPods(ctx context.Context, stopper *stop.Stopper) e return } } - - // Update the directory entry for the tenant with the latest - // information about this pod. 
- d.updateTenantEntry(ctx, resp.Pod) } }) if err != nil { diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go b/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go index 380df7ca2522..45f929c9f9c2 100644 --- a/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go +++ b/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go @@ -21,14 +21,18 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenantdirsvr" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" ) @@ -67,127 +71,109 @@ func TestDirectoryErrors(t *testing.T) { func TestWatchPods(t *testing.T) { defer leaktest.AfterTest(t)() - defer log.ScopeWithoutShowLogs(t).Close(t) - skip.UnderDeadlockWithIssue(t, 71365) + defer log.Scope(t).Close(t) + ctx := context.Background() // Make pod watcher channel. podWatcher := make(chan *tenant.Pod, 1) - // Create the directory. - ctx := context.Background() - tc, dir, tds := newTestDirectoryCache(t, tenant.PodWatcher(podWatcher)) - defer tc.Stopper().Stop(ctx) + // Setup test directory cache and server. + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + dir, tds := setupTestDirectory(t, ctx, stopper, nil /* timeSource */, tenant.PodWatcher(podWatcher)) + + // Wait until the watcher has been established. + testutils.SucceedsSoon(t, func() error { + if tds.WatchListenersCount() == 0 { + return errors.New("watchers have not been established yet") + } + return nil + }) tenantID := roachpb.MakeTenantID(20) - require.NoError(t, createTenant(tc, tenantID)) + tds.CreateTenant(tenantID, "my-tenant") + + // Add a new pod to the tenant. + runningPod := &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: "127.0.0.10:10", + State: tenant.RUNNING, + StateTimestamp: timeutil.Now(), + } + require.True(t, tds.AddPod(tenantID, runningPod)) + pod := <-podWatcher + require.Equal(t, runningPod, pod) - // Call LookupTenantPods to start a new tenant and create an entry. - pods, err := dir.LookupTenantPods(ctx, tenantID, "") + // Directory cache should have already been updated. + pods, err := dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - require.NotEmpty(t, pods) - addr := pods[0].Addr - - // Ensure that correct event was sent to watcher channel. - pod := <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.RUNNING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) + require.Len(t, pods, 1) + require.Equal(t, runningPod, pods[0]) - // Trigger drain of pod. - tds.Drain() + // Drain the pod. + require.True(t, tds.DrainPod(tenantID, runningPod.Addr)) pod = <-podWatcher require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) + require.Equal(t, runningPod.Addr, pod.Addr) require.Equal(t, tenant.DRAINING, pod.State) require.False(t, pod.StateTimestamp.IsZero()) - // Now shut the tenant directory down. 
- processes := tds.Get(tenantID) - require.NotNil(t, processes) - require.Len(t, processes, 1) - // Stop the tenant and ensure its IP address is removed from the directory. - for _, process := range processes { - process.Stopper.Stop(ctx) - } - - // Ensure that correct event was sent to watcher channel. - pod = <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.DELETING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) - - // Ensure that all addresses have been cleared from the directory, since - // it should only return RUNNING or DRAINING addresses. - require.Eventually(t, func() bool { - tenantPods, _ := dir.TryLookupTenantPods(ctx, tenantID) - return len(tenantPods) == 0 - }, 10*time.Second, 100*time.Millisecond) - - // Resume tenant again by a direct call to the directory server - _, err = tds.EnsurePod(ctx, &tenant.EnsurePodRequest{tenantID.ToUint64()}) + // Directory cache should be updated with the DRAINING pod. + pods, err = dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - - // Wait for background watcher to populate the initial pod. - require.Eventually(t, func() bool { - tenantPods, _ := dir.TryLookupTenantPods(ctx, tenantID) - if len(tenantPods) != 0 { - addr = tenantPods[0].Addr - return true + require.Len(t, pods, 1) + require.Equal(t, pod, pods[0]) + + // Trigger the directory server to restart. WatchPods should handle + // reconnection properly. + // + // NOTE: We check for the number of listeners before proceeding with the + // pod update (e.g. AddPod) because if we don't do that, there could be a + // situation where AddPod gets called before the watcher gets established, + // which means that the pod update event will never get emitted. This is + // only a test directory server issue due to its simple implementation. One + // way to solve this nicely is to implement checkpointing based on the sent + // updates (just like how Kubernetes bookmarks work). + tds.Stop(ctx) + testutils.SucceedsSoon(t, func() error { + if tds.WatchListenersCount() != 0 { + return errors.New("watchers have not been removed yet") } - return false - }, 10*time.Second, 100*time.Millisecond) + return nil + }) + require.NoError(t, tds.Start(ctx)) + testutils.SucceedsSoon(t, func() error { + if tds.WatchListenersCount() == 0 { + return errors.New("watchers have not been established yet") + } + return nil + }) - // Ensure that correct event was sent to watcher channel. + // Put the same pod back to running. + require.True(t, tds.AddPod(tenantID, runningPod)) pod = <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.RUNNING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) + require.Equal(t, runningPod, pod) - // Verify that LookupTenantPods returns the pod's IP address. - pods, err = dir.LookupTenantPods(ctx, tenantID, "") + // Directory cache should be updated with the RUNNING pod. + pods, err = dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - require.NotEmpty(t, pods) - addr = pods[0].Addr - - processes = tds.Get(tenantID) - require.NotNil(t, processes) - require.Len(t, processes, 1) - for dirAddr := range processes { - require.Equal(t, addr, dirAddr.String()) - } - - // Stop the tenant and ensure its IP address is removed from the directory. 
- for _, process := range processes { - process.Stopper.Stop(ctx) - } - - require.Eventually(t, func() bool { - tenantPods, _ := dir.TryLookupTenantPods(ctx, tenantID) - return len(tenantPods) == 0 - }, 10*time.Second, 100*time.Millisecond) + require.Len(t, pods, 1) + require.Equal(t, pod, pods[0]) - // Ensure that correct event was sent to watcher channel. + // Delete the pod. + require.True(t, tds.RemovePod(tenantID, runningPod.Addr)) pod = <-podWatcher require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) + require.Equal(t, runningPod.Addr, pod.Addr) require.Equal(t, tenant.DELETING, pod.State) require.False(t, pod.StateTimestamp.IsZero()) - // Verify that a new call to LookupTenantPods will resume again the tenant. - pods, err = dir.LookupTenantPods(ctx, tenantID, "") + // Directory cache should have no pods. + pods, err = dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - require.NotEmpty(t, pods) - addr = pods[0].Addr - - // Ensure that correct event was sent to watcher channel. - pod = <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.RUNNING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) + require.Empty(t, pods) + stopper.Stop(ctx) + stopper.Stop(ctx) } func TestCancelLookups(t *testing.T) { @@ -434,6 +420,40 @@ func startTenant( return &tenantdirsvr.Process{SQL: sqlAddr, Stopper: tenantStopper}, nil } +// setupTestDirectory returns an instance of the directory cache and the +// in-memory test static directory server. Tenants will need to be added/removed +// manually. +func setupTestDirectory( + t *testing.T, + ctx context.Context, + stopper *stop.Stopper, + timeSource timeutil.TimeSource, + opts ...tenant.DirOption, +) (tenant.DirectoryCache, *tenantdirsvr.TestStaticDirectoryServer) { + t.Helper() + + // Start an in-memory static directory server. + directoryServer := tenantdirsvr.NewTestStaticDirectoryServer(stopper, timeSource) + require.NoError(t, directoryServer.Start(ctx)) + + // Dial the test directory server. + conn, err := grpc.DialContext( + ctx, + "", + grpc.WithContextDialer(directoryServer.DialerFunc), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + stopper.AddCloser(stop.CloserFn(func() { + _ = conn.Close() // nolint:grpcconnclose + })) + client := tenant.NewDirectoryClient(conn) + directoryCache, err := tenant.NewDirectoryCache(ctx, stopper, client, opts...) + require.NoError(t, err) + + return directoryCache, directoryServer +} + // Setup directory cache that uses a client connected to a test directory server // that manages tenants connected to a backing KV server. func newTestDirectoryCache( @@ -464,7 +484,7 @@ func newTestDirectoryCache( return process, nil } - listenPort, err := net.Listen("tcp", ":0") + listenPort, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = tds.Serve(listenPort) }() diff --git a/pkg/ccl/sqlproxyccl/tenant/entry.go b/pkg/ccl/sqlproxyccl/tenant/entry.go index e68cc0ff8c32..b3e99d246bf7 100644 --- a/pkg/ccl/sqlproxyccl/tenant/entry.go +++ b/pkg/ccl/sqlproxyccl/tenant/entry.go @@ -239,6 +239,7 @@ func (e *tenantEntry) fetchPodsLocked( ctx context.Context, client DirectoryClient, ) (tenantPods []*Pod, err error) { // List the pods for the given tenant. + // // TODO(andyk): This races with the pod watcher, which may receive updates // that are newer than what ListPods returns. 
This could be fixed by adding // version values to the pods in order to detect races. diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel index c577e2e78f61..f56109730cd3 100644 --- a/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "in_mem_listener.go", "test_directory_svr.go", "test_simple_directory_svr.go", + "test_static_directory_svr.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenantdirsvr", visibility = ["//visibility:public"], diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go index cd8c780b2dfa..3f387ec8ec98 100644 --- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go @@ -48,7 +48,7 @@ type Process struct { // is a possibility that between the two calls, the parent stopper completes a // stop and then the leak detection may find a leaked stopper. func NewSubStopper(parentStopper *stop.Stopper) *stop.Stopper { - mu := &syncutil.Mutex{} + var mu syncutil.Mutex var subStopper *stop.Stopper parentStopper.AddCloser(stop.CloserFn(func() { mu.Lock() diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go index 6961a84407cf..9f0d5565c6aa 100644 --- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go @@ -44,7 +44,7 @@ type TestSimpleDirectoryServer struct { var _ tenant.DirectoryServer = &TestSimpleDirectoryServer{} // NewTestSimpleDirectoryServer constructs a new simple directory server. -func NewTestSimpleDirectoryServer(podAddr string) (tenant.DirectoryServer, *grpc.Server) { +func NewTestSimpleDirectoryServer(podAddr string) (*TestSimpleDirectoryServer, *grpc.Server) { dir := &TestSimpleDirectoryServer{podAddr: podAddr} dir.mu.deleted = make(map[roachpb.TenantID]struct{}) grpcServer := grpc.NewServer() diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go new file mode 100644 index 000000000000..ed40d2113435 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go @@ -0,0 +1,432 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package tenantdirsvr + +import ( + "container/list" + "context" + "net" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/gogo/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/test/bufconn" +) + +// TestStaticDirectoryServer is a directory server that stores a static mapping +// of tenants to their pods. Callers will need to invoke operations on the +// directory server to create/delete/update tenants and pods as necessary. 
+// Unlike the regular directory server that automatically spins up SQL pods when +// one isn't present, the static directory server does not do that. The caller +// should start the SQL pods up, and register them with the directory server. +type TestStaticDirectoryServer struct { + // rootStopper is used to create sub-stoppers for directory instances. + rootStopper *stop.Stopper + + // timeSource is the source of the time. By default, this will be set to + // timeutil.DefaultTimeSource. + timeSource timeutil.TimeSource + + // process corresponds to fields that need to be updated together whenever + // the test directory is started or stopped. Note that when the test + // directory server gets restarted, we do not reset the tenant data. + process struct { + syncutil.Mutex + + // stopper is used to start async tasks within the directory server. + stopper *stop.Stopper + + // ln corresponds to the in-memory listener for the directory + // server. This will be used to construct a dialer function. + ln *bufconn.Listener + + // grpcServer corresponds to the GRPC server instance for the + // directory server. + grpcServer *grpc.Server + } + + mu struct { + syncutil.Mutex + + // tenants stores a list of pods for every tenant. If a tenant does not + // exist in the map, the tenant is assumed to be non-existent. Since + // this is a test directory server, we do not care about fine-grained + // locking here. + tenants map[roachpb.TenantID][]*tenant.Pod + + // tenantNames stores a list of clusterNames associated with each tenant. + tenantNames map[roachpb.TenantID]string + + // eventListeners stores a list of listeners that are watching changes + // to pods through WatchPods. + eventListeners *list.List + } +} + +var _ tenant.DirectoryServer = &TestStaticDirectoryServer{} + +// NewTestStaticDirectoryServer constructs a new static directory server. +func NewTestStaticDirectoryServer( + stopper *stop.Stopper, timeSource timeutil.TimeSource, +) *TestStaticDirectoryServer { + if timeSource == nil { + timeSource = timeutil.DefaultTimeSource{} + } + dir := &TestStaticDirectoryServer{rootStopper: stopper, timeSource: timeSource} + dir.mu.tenants = make(map[roachpb.TenantID][]*tenant.Pod) + dir.mu.tenantNames = make(map[roachpb.TenantID]string) + dir.mu.eventListeners = list.New() + return dir +} + +// ListPods returns a list with all SQL pods associated with the given tenant. +// If the tenant does not exists, no pods will be returned. +// +// ListPods implements the tenant.DirectoryServer interface. +func (d *TestStaticDirectoryServer) ListPods( + ctx context.Context, req *tenant.ListPodsRequest, +) (*tenant.ListPodsResponse, error) { + d.mu.Lock() + defer d.mu.Unlock() + + pods, ok := d.mu.tenants[roachpb.MakeTenantID(req.TenantID)] + if !ok { + return &tenant.ListPodsResponse{}, nil + } + + // Return a copy of all the pods to avoid any race issues. + tenantPods := make([]*tenant.Pod, len(pods)) + for i, pod := range pods { + copyPod := *pod + tenantPods[i] = ©Pod + } + return &tenant.ListPodsResponse{Pods: tenantPods}, nil +} + +// WatchPods allows callers to monitor for pod update events. +// +// WatchPods implements the tenant.DirectoryServer interface. +func (d *TestStaticDirectoryServer) WatchPods( + req *tenant.WatchPodsRequest, server tenant.Directory_WatchPodsServer, +) error { + d.process.Lock() + stopper := d.process.stopper + d.process.Unlock() + + // This cannot happen unless WatchPods was called directly, which we + // shouldn't since it is meant to be called through a GRPC client. 
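+	// A nil stopper indicates that the server has not been started.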
+ if stopper == nil { + return status.Errorf(codes.FailedPrecondition, "directory server has not been started") + } + + addListener := func(ch chan *tenant.WatchPodsResponse) *list.Element { + d.mu.Lock() + defer d.mu.Unlock() + return d.mu.eventListeners.PushBack(ch) + } + removeListener := func(e *list.Element) chan *tenant.WatchPodsResponse { + d.mu.Lock() + defer d.mu.Unlock() + return d.mu.eventListeners.Remove(e).(chan *tenant.WatchPodsResponse) + } + + // Construct the channel with a small buffer to allow for a burst of + // notifications, and a slow receiver. + c := make(chan *tenant.WatchPodsResponse, 10) + chElement := addListener(c) + + return stopper.RunTask( + context.Background(), + "watch-pods-server", + func(ctx context.Context) { + defer func() { + if ch := removeListener(chElement); ch != nil { + close(ch) + } + }() + + for watch := true; watch; { + select { + case e, ok := <-c: + // Channel was closed. + if !ok { + watch = false + break + } + if err := server.Send(e); err != nil { + watch = false + } + case <-stopper.ShouldQuiesce(): + watch = false + } + } + }, + ) +} + +// EnsurePod returns an empty response if a tenant with the given tenant ID +// exists, and there is at least one SQL pod. If there are no SQL pods, a GRPC +// FailedPrecondition error will be returned. Similarly, if the tenant does not +// exists, a GRPC NotFound error will be returned. This would mimic the behavior +// that we have in the actual tenant directory. +// +// EnsurePod implements the tenant.DirectoryServer interface. +func (d *TestStaticDirectoryServer) EnsurePod( + ctx context.Context, req *tenant.EnsurePodRequest, +) (*tenant.EnsurePodResponse, error) { + d.mu.Lock() + defer d.mu.Unlock() + + pods, ok := d.mu.tenants[roachpb.MakeTenantID(req.TenantID)] + if !ok { + return nil, status.Errorf(codes.NotFound, "tenant does not exist") + } + if len(pods) == 0 { + return nil, status.Errorf(codes.FailedPrecondition, "tenant has no pods") + } + return &tenant.EnsurePodResponse{}, nil +} + +// GetTenant returns tenant metadata associated with the given tenant ID. If the +// tenant isn't in the directory server, a GRPC NotFound error will be returned. +// +// GetTenant implements the tenant.DirectoryServer interface. +func (d *TestStaticDirectoryServer) GetTenant( + ctx context.Context, req *tenant.GetTenantRequest, +) (*tenant.GetTenantResponse, error) { + d.mu.Lock() + defer d.mu.Unlock() + + tenantID := roachpb.MakeTenantID(req.TenantID) + if _, ok := d.mu.tenants[tenantID]; !ok { + return nil, status.Errorf(codes.NotFound, "tenant does not exist") + } + return &tenant.GetTenantResponse{ClusterName: d.mu.tenantNames[tenantID]}, nil +} + +// CreateTenant creates a tenant with the given tenant ID in the directory +// server. If the tenant already exists, this is a no-op. +func (d *TestStaticDirectoryServer) CreateTenant(tenantID roachpb.TenantID, clusterName string) { + d.mu.Lock() + defer d.mu.Unlock() + + if _, ok := d.mu.tenants[tenantID]; ok { + return + } + d.mu.tenants[tenantID] = make([]*tenant.Pod, 0) + d.mu.tenantNames[tenantID] = clusterName +} + +// DeleteTenant ensures that the tenant with the given tenant ID has been +// removed from the directory server. Doing this would return a NotFound error +// for certain directory server endpoints. This also changes the behavior of +// ListPods so no pods are returned for the given tenant. 
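+//
+// Any remaining pods are transitioned to DELETING, and the corresponding
+// events are emitted to WatchPods listeners, before the tenant is removed.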
+func (d *TestStaticDirectoryServer) DeleteTenant(tenantID roachpb.TenantID) { + d.mu.Lock() + defer d.mu.Unlock() + + pods, ok := d.mu.tenants[tenantID] + if !ok { + return + } + for _, pod := range pods { + // Update pod to DELETING, and emit event. + pod.State = tenant.DELETING + pod.StateTimestamp = d.timeSource.Now() + d.notifyPodUpdateLocked(pod) + } + delete(d.mu.tenants, tenantID) + delete(d.mu.tenantNames, tenantID) +} + +// AddPod adds the pod to the given tenant pod's list. If a SQL pod with the +// same address already exists, this updates the existing pod. AddPod returns +// true if the operation was successful, and false otherwise. +// +// NOTE: pod has to be fully populated, and pod.TenantID should be the same as +// the given tenantID. +func (d *TestStaticDirectoryServer) AddPod(tenantID roachpb.TenantID, pod *tenant.Pod) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // Tenant does not exist. + pods, ok := d.mu.tenants[tenantID] + if !ok { + return false + } + + // Emit an event that the pod has been created. + d.notifyPodUpdateLocked(pod) + + // Check if the pod exists. This would handle pods transitioning from + // DRAINING to RUNNING. + for i, existing := range pods { + if existing.Addr == pod.Addr { + d.mu.tenants[tenantID][i] = pod + return true + } + } + + // A new pod has been added. + d.mu.tenants[tenantID] = append(d.mu.tenants[tenantID], pod) + return true +} + +// DrainPod puts the tenant associated SQL pod with the given address into the +// DRAINING state. DrainPod returns true if the operation was successful, and +// false otherwise. +func (d *TestStaticDirectoryServer) DrainPod(tenantID roachpb.TenantID, podAddr string) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // Tenant does not exist. + pods, ok := d.mu.tenants[tenantID] + if !ok { + return false + } + + // If the pod exists, update its state to DRAINING. + for _, existing := range pods { + if existing.Addr == podAddr { + existing.State = tenant.DRAINING + existing.StateTimestamp = d.timeSource.Now() + d.notifyPodUpdateLocked(existing) + return true + } + } + return false +} + +// RemovePod deletes the SQL pod with the given address from the associated +// tenant. RemovePod returns true if the operation was successful, and false +// otherwise. +func (d *TestStaticDirectoryServer) RemovePod(tenantID roachpb.TenantID, podAddr string) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // Tenant does not exist. + pods, ok := d.mu.tenants[tenantID] + if !ok { + return false + } + + // If the pod exists, remove it. + for i, existing := range pods { + if existing.Addr == podAddr { + // Remove pod. + copy(d.mu.tenants[tenantID][i:], d.mu.tenants[tenantID][i+1:]) + d.mu.tenants[tenantID] = d.mu.tenants[tenantID][:len(d.mu.tenants[tenantID])-1] + + // Update pod to DELETING, and emit event. + existing.State = tenant.DELETING + existing.StateTimestamp = d.timeSource.Now() + d.notifyPodUpdateLocked(existing) + return true + } + } + return false +} + +// Start starts the test directory server using an in-memory listener. This +// returns an error if the server cannot be started. If the sevrer has already +// been started, this is a no-op. +func (d *TestStaticDirectoryServer) Start(ctx context.Context) error { + d.process.Lock() + defer d.process.Unlock() + + // Server has already been started. 
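+	// The in-memory listener is only set while the server is running.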
+ if d.process.ln != nil { + return nil + } + + stopper := NewSubStopper(d.rootStopper) + grpcServer := grpc.NewServer() + tenant.RegisterDirectoryServer(grpcServer, d) + ln, err := ListenAndServeInMemGRPC(ctx, stopper, grpcServer) + if err != nil { + return err + } + + // Update instance fields together. + d.process.stopper = stopper + d.process.ln = ln + d.process.grpcServer = grpcServer + return nil +} + +// Stop stops the test directory server instance. If the server has already +// been stopped, this is a no-op. +func (d *TestStaticDirectoryServer) Stop(ctx context.Context) { + d.process.Lock() + defer d.process.Unlock() + + stopper, ln, grpcServer := d.process.stopper, d.process.ln, d.process.grpcServer + if ln == nil { + return + } + + d.process.stopper = nil + d.process.ln = nil + d.process.grpcServer = nil + + // Close listener first, followed by GRPC server, and stopper. + _ = ln.Close() + grpcServer.Stop() + stopper.Stop(ctx) +} + +// DialerFunc corresponds to the dialer function used to dial the test directory +// server. We do this because the test directory server runs in-memory, and +// does not bind to a physical network address. This will be used within +// grpc.WithContextDialer. +func (d *TestStaticDirectoryServer) DialerFunc(ctx context.Context, addr string) (net.Conn, error) { + d.process.Lock() + listener := d.process.ln + d.process.Unlock() + + if listener == nil { + return nil, errors.New("directory server has not been started") + } + return listener.DialContext(ctx) +} + +// WatchListenersCount returns the number of active listeners for pod update +// events. +func (d *TestStaticDirectoryServer) WatchListenersCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.mu.eventListeners.Len() +} + +// notifyPodUpdateLocked sends a pod update event to all WatchPods listeners. +func (d *TestStaticDirectoryServer) notifyPodUpdateLocked(pod *tenant.Pod) { + // Make a copy of the pod to prevent race issues. + copyPod := *pod + res := &tenant.WatchPodsResponse{Pod: ©Pod} + + for e := d.mu.eventListeners.Front(); e != nil; { + select { + case e.Value.(chan *tenant.WatchPodsResponse) <- res: + e = e.Next() + default: + // The receiver is unable to consume fast enough. Close the channel + // and remove it from the list. + eToClose := e + e = e.Next() + ch := d.mu.eventListeners.Remove(eToClose) + close(ch.(chan *tenant.WatchPodsResponse)) + } + } +} diff --git a/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go b/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go index 7946c7e96d48..a46962bf1e0f 100644 --- a/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go +++ b/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go @@ -36,9 +36,7 @@ func TestShowTransferState(t *testing.T) { _, err := tenantDB.Exec("CREATE USER testuser WITH PASSWORD 'hunter2'") require.NoError(t, err) - // TODO(rafi): use ALTER TENANT ALL when available. 
- _, err = mainDB.Exec(`INSERT INTO system.tenant_settings (tenant_id, name, value, value_type) VALUES - (0, 'server.user_login.session_revival_token.enabled', 'true', 'b')`) + _, err = mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") require.NoError(t, err) t.Run("without_transfer_key", func(t *testing.T) { diff --git a/pkg/testutils/lint/lint_test.go b/pkg/testutils/lint/lint_test.go index 53e654b99d6f..38b7fbea166e 100644 --- a/pkg/testutils/lint/lint_test.go +++ b/pkg/testutils/lint/lint_test.go @@ -866,6 +866,7 @@ func TestLint(t *testing.T) { ":!util/tracing/*_test.go", ":!ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go", ":!ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go", + ":!ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go", ) if err != nil { t.Fatal(err) From 4d886128c423b0c0ca63a271f0d18c699ad1f1c7 Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 11 May 2022 03:23:59 +0000 Subject: [PATCH 2/2] ccl/sqlproxyccl: rate limit the number of rebalances per tenant This commit rate limits the number of rebalances per tenant to once every 15 seconds (i.e. 1/2 of the rebalance loop interval). The main purpose of this is to prevent a burst of pod events for the same tenant causing multiple rebalances, which may move a lot of connections around. Release note: None --- pkg/ccl/sqlproxyccl/balancer/balancer.go | 148 ++++++++++++------ pkg/ccl/sqlproxyccl/balancer/balancer_test.go | 104 ++++++++++++ 2 files changed, 200 insertions(+), 52 deletions(-) diff --git a/pkg/ccl/sqlproxyccl/balancer/balancer.go b/pkg/ccl/sqlproxyccl/balancer/balancer.go index a67b4ffea1b8..f000bc20d440 100644 --- a/pkg/ccl/sqlproxyccl/balancer/balancer.go +++ b/pkg/ccl/sqlproxyccl/balancer/balancer.go @@ -38,6 +38,12 @@ const ( // DRAINING state before the proxy starts moving connections away from it. minDrainPeriod = 1 * time.Minute + // defaultRebalanceDelay is the minimum amount of time that must elapse + // between rebalance operations. This was deliberately chosen to be half of + // rebalanceInterval, and is mainly used to rate limit effects due to events + // from the pod watcher. + defaultRebalanceDelay = 15 * time.Second + // rebalancePercentDeviation defines the percentage threshold that the // current number of assignments can deviate away from the mean. Having a // 15% "deadzone" reduces frequent transfers especially when load is @@ -79,6 +85,7 @@ type balancerOptions struct { noRebalanceLoop bool timeSource timeutil.TimeSource rebalanceRate float32 + rebalanceDelay time.Duration } // Option defines an option that can be passed to NewBalancer in order to @@ -110,14 +117,25 @@ func TimeSource(ts timeutil.TimeSource) Option { } } -// RebalanceRate defines the rate of rebalancing across pods. Set to -1 to -// disable rebalancing (i.e. connection transfers). +// RebalanceRate defines the rate of rebalancing across pods. This must be +// between 0 and 1 inclusive. 0 means no rebalancing will occur. func RebalanceRate(rate float32) Option { return func(opts *balancerOptions) { opts.rebalanceRate = rate } } +// RebalanceDelay specifies the minimum amount of time that must elapse between +// attempts to rebalance a given tenant. This delay has the effect of throttling +// RebalanceTenant calls to avoid constantly moving connections around. +// +// RebalanceDelay defaults to defaultRebalanceDelay. Use -1 to never throttle. 
+func RebalanceDelay(delay time.Duration) Option { + return func(opts *balancerOptions) { + opts.rebalanceDelay = delay + } +} + // Balancer handles load balancing of SQL connections within the proxy. // All methods on the Balancer instance are thread-safe. type Balancer struct { @@ -152,6 +170,19 @@ type Balancer struct { // rebalanceRate represents the rate of rebalancing connections. rebalanceRate float32 + + // rebalanceDelay is the minimum amount of time that must elapse between + // attempts to rebalance a given tenant. Defaults to defaultRebalanceDelay. + rebalanceDelay time.Duration + + // lastRebalance is the last time the tenants are rebalanced. This is used + // to rate limit the number of rebalances per tenant. Synchronization is + // needed since rebalance operations can be triggered by the rebalance loop, + // or the pod watcher. + lastRebalance struct { + syncutil.Mutex + tenants map[roachpb.TenantID]time.Time + } } // NewBalancer constructs a new Balancer instance that is responsible for @@ -164,22 +195,15 @@ func NewBalancer( opts ...Option, ) (*Balancer, error) { // Handle options. - options := &balancerOptions{} + options := &balancerOptions{ + maxConcurrentRebalances: defaultMaxConcurrentRebalances, + timeSource: timeutil.DefaultTimeSource{}, + rebalanceRate: defaultRebalanceRate, + rebalanceDelay: defaultRebalanceDelay, + } for _, opt := range opts { opt(options) } - if options.maxConcurrentRebalances == 0 { - options.maxConcurrentRebalances = defaultMaxConcurrentRebalances - } - if options.timeSource == nil { - options.timeSource = timeutil.DefaultTimeSource{} - } - if options.rebalanceRate == 0 { - options.rebalanceRate = defaultRebalanceRate - } - if options.rebalanceRate == -1 { - options.rebalanceRate = 0 - } // Ensure that ctx gets cancelled on stopper's quiescing. ctx, _ = stopper.WithCancelOnQuiesce(ctx) @@ -197,7 +221,10 @@ func NewBalancer( processSem: semaphore.New(options.maxConcurrentRebalances), timeSource: options.timeSource, rebalanceRate: options.rebalanceRate, + rebalanceDelay: options.rebalanceDelay, } + b.lastRebalance.tenants = make(map[roachpb.TenantID]time.Time) + b.connTracker, err = NewConnTracker(ctx, b.stopper, b.timeSource) if err != nil { return nil, err @@ -218,6 +245,46 @@ func NewBalancer( return b, nil } +// RebalanceTenant rebalances connections to the given tenant. If no RUNNING +// pod exists for the given tenant, or the tenant has been recently rebalanced, +// this is a no-op. +func (b *Balancer) RebalanceTenant(ctx context.Context, tenantID roachpb.TenantID) { + // If rebalanced recently, no-op. + if !b.canRebalanceTenant(tenantID) { + return + } + + tenantPods, err := b.directoryCache.TryLookupTenantPods(ctx, tenantID) + if err != nil { + log.Errorf(ctx, "could not rebalance tenant %s: %v", tenantID, err.Error()) + return + } + + // Construct a map so we could easily retrieve the pod by address. + podMap := make(map[string]*tenant.Pod) + var hasRunningPod bool + for _, pod := range tenantPods { + podMap[pod.Addr] = pod + + if pod.State == tenant.RUNNING { + hasRunningPod = true + } + } + + // Only attempt to rebalance if we have a RUNNING pod. In theory, this + // case would happen if we're scaling down from 1 to 0, which in that + // case, we can't transfer connections anywhere. Practically, we will + // never scale a tenant from 1 to 0 if there are still active + // connections, so this case should not occur. 
+	if !hasRunningPod {
+		return
+	}
+
+	activeList, idleList := b.connTracker.listAssignments(tenantID)
+	b.rebalancePartition(podMap, activeList)
+	b.rebalancePartition(podMap, idleList)
+}
+
 // SelectTenantPod selects a tenant pod from the given list based on a weighted
 // CPU load algorithm. It is expected that all pods within the list belongs to
 // the same tenant. If no pods are available, this returns ErrNoAvailablePods.
@@ -318,6 +385,20 @@ func (b *Balancer) rebalanceLoop(ctx context.Context) {
 	}
 }
 
+// canRebalanceTenant returns true if it has been at least `rebalanceDelay`
+// since the last time the given tenant was rebalanced, or false otherwise.
+func (b *Balancer) canRebalanceTenant(tenantID roachpb.TenantID) bool {
+	b.lastRebalance.Lock()
+	defer b.lastRebalance.Unlock()
+
+	now := b.timeSource.Now()
+	if now.Sub(b.lastRebalance.tenants[tenantID]) < b.rebalanceDelay {
+		return false
+	}
+	b.lastRebalance.tenants[tenantID] = now
+	return true
+}
+
 // rebalance attempts to rebalance connections for all tenants within the proxy.
 func (b *Balancer) rebalance(ctx context.Context) {
 	// getTenantIDs ensures that tenants will have at least one connection.
@@ -327,43 +408,6 @@ func (b *Balancer) rebalance(ctx context.Context) {
 	}
 }
 
-// RebalanceTenant rebalances connections for the given tenant. If no RUNNING
-// pod exists for the given tenant, this is a no-op.
-//
-// TODO(jaylim-crl): Rate limit the number of rebalances per tenant for requests
-// coming from the pod watcher.
-func (b *Balancer) RebalanceTenant(ctx context.Context, tenantID roachpb.TenantID) {
-	tenantPods, err := b.directoryCache.TryLookupTenantPods(ctx, tenantID)
-	if err != nil {
-		log.Errorf(ctx, "could not rebalance tenant %s: %v", tenantID, err.Error())
-		return
-	}
-
-	// Construct a map so we could easily retrieve the pod by address.
-	podMap := make(map[string]*tenant.Pod)
-	var hasRunningPod bool
-	for _, pod := range tenantPods {
-		podMap[pod.Addr] = pod
-
-		if pod.State == tenant.RUNNING {
-			hasRunningPod = true
-		}
-	}
-
-	// Only attempt to rebalance if we have a RUNNING pod. In theory, this
-	// case would happen if we're scaling down from 1 to 0, which in that
-	// case, we can't transfer connections anywhere. Practically, we will
-	// never scale a tenant from 1 to 0 if there are still active
-	// connections, so this case should not occur.
-	if !hasRunningPod {
-		return
-	}
-
-	activeList, idleList := b.connTracker.listAssignments(tenantID)
-	b.rebalancePartition(podMap, activeList)
-	b.rebalancePartition(podMap, idleList)
-}
-
 // rebalancePartition rebalances the given assignments partition.
 func (b *Balancer) rebalancePartition(
 	pods map[string]*tenant.Pod, assignments []*ServerAssignment,
diff --git a/pkg/ccl/sqlproxyccl/balancer/balancer_test.go b/pkg/ccl/sqlproxyccl/balancer/balancer_test.go
index 6395823f60a6..fa59896d3f37 100644
--- a/pkg/ccl/sqlproxyccl/balancer/balancer_test.go
+++ b/pkg/ccl/sqlproxyccl/balancer/balancer_test.go
@@ -365,6 +365,7 @@ func TestRebalancer_rebalance(t *testing.T) {
 		directoryCache,
 		NoRebalanceLoop(),
 		TimeSource(timeSource),
+		RebalanceDelay(-1),
 	)
 	require.NoError(t, err)
 
@@ -715,6 +716,109 @@ func TestRebalancer_rebalance(t *testing.T) {
 	}
 }
 
+func TestBalancer_RebalanceTenant_WithDefaultDelay(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	ctx := context.Background()
+	stopper := stop.NewStopper()
+	defer stopper.Stop(ctx)
+
+	// Use a custom time source for testing.
+	t0 := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC)
+	timeSource := timeutil.NewManualTime(t0)
+
+	metrics := NewMetrics()
+	directoryCache := newTestDirectoryCache()
+
+	b, err := NewBalancer(
+		ctx,
+		stopper,
+		metrics,
+		directoryCache,
+		NoRebalanceLoop(),
+		TimeSource(timeSource),
+	)
+	require.NoError(t, err)
+
+	tenantID := roachpb.MakeTenantID(10)
+	pods := []*tenant.Pod{
+		{TenantID: tenantID.ToUint64(), Addr: "127.0.0.30:80", State: tenant.DRAINING},
+		{TenantID: tenantID.ToUint64(), Addr: "127.0.0.30:81", State: tenant.RUNNING},
+	}
+	for _, pod := range pods {
+		require.True(t, directoryCache.upsertPod(pod))
+	}
+
+	// Create 100 active connections, all to the draining pod.
+	const numConns = 100
+	var mu syncutil.Mutex
+	assignments := make([]*ServerAssignment, numConns)
+	makeTestConnHandle := func(idx int) *testConnHandle {
+		var handle *testConnHandle
+		handle = &testConnHandle{
+			onTransferConnection: func() error {
+				mu.Lock()
+				defer mu.Unlock()
+				assignments[idx].Close()
+				assignments[idx] = NewServerAssignment(
+					tenantID, b.connTracker, handle, pods[1].Addr,
+				)
+				return nil
+			},
+		}
+		return handle
+	}
+	var handles []ConnectionHandle
+	for i := 0; i < numConns; i++ {
+		handle := makeTestConnHandle(i)
+		handles = append(handles, handle)
+		assignments[i] = NewServerAssignment(
+			tenantID, b.connTracker, handle, pods[0].Addr,
+		)
+	}
+
+	waitFor := func(numTransfers int) {
+		testutils.SucceedsSoon(t, func() error {
+			count := 0
+			for i := 0; i < 100; i++ {
+				count += handles[i].(*testConnHandle).transferConnectionCount()
+			}
+			if count != numTransfers {
+				return errors.Newf("require %d, but got %v", numTransfers, count)
+			}
+			return nil
+		})
+	}
+
+	// Attempt the rebalance, and wait until 50 were moved
+	// (i.e. 100 * defaultRebalanceRate).
+	b.RebalanceTenant(ctx, tenantID)
+	waitFor(50)
+
+	// Run the rebalance again.
+	b.RebalanceTenant(ctx, tenantID)
+
+	// Queue should be empty, and no additional connections should be moved.
+	b.queue.mu.Lock()
+	queueLen := b.queue.queue.Len()
+	b.queue.mu.Unlock()
+	require.Equal(t, 0, queueLen)
+	waitFor(50)
+
+	// Advance time, rebalance, and wait until 75 (i.e. 50 + 25) connections
+	// get moved.
+	timeSource.Advance(defaultRebalanceDelay)
+	b.RebalanceTenant(ctx, tenantID)
+	waitFor(75)
+
+	// Advance time, rebalance, and wait until 88 (i.e. 75 + 13) connections
+	// get moved.
+	timeSource.Advance(defaultRebalanceDelay)
+	b.RebalanceTenant(ctx, tenantID)
+	waitFor(88)
+}
+
 func TestEnqueueRebalanceRequests(t *testing.T) {
 	defer leaktest.AfterTest(t)()