diff --git a/pkg/ccl/sqlproxyccl/balancer/balancer.go b/pkg/ccl/sqlproxyccl/balancer/balancer.go index b15cfb3d8d64..f000bc20d440 100644 --- a/pkg/ccl/sqlproxyccl/balancer/balancer.go +++ b/pkg/ccl/sqlproxyccl/balancer/balancer.go @@ -38,6 +38,12 @@ const ( // DRAINING state before the proxy starts moving connections away from it. minDrainPeriod = 1 * time.Minute + // defaultRebalanceDelay is the minimum amount of time that must elapse + // between rebalance operations. This was deliberately chosen to be half of + // rebalanceInterval, and is mainly used to rate limit effects due to events + // from the pod watcher. + defaultRebalanceDelay = 15 * time.Second + // rebalancePercentDeviation defines the percentage threshold that the // current number of assignments can deviate away from the mean. Having a // 15% "deadzone" reduces frequent transfers especially when load is @@ -50,15 +56,15 @@ const ( // NOTE: This must be between 0 and 1 inclusive. rebalancePercentDeviation = 0.15 - // rebalanceRate defines the rate of rebalancing assignments across SQL - // pods. This rate applies to both RUNNING and DRAINING pods. For example, - // consider the case where the rate is 0.50; if we have decided that we need - // to move 15 assignments away from a particular pod, only 7 pods will be - // moved at a time. + // defaultRebalanceRate defines the rate of rebalancing assignments across + // SQL pods. This rate applies to both RUNNING and DRAINING pods. For + // example, consider the case where the rate is 0.50; if we have decided + // that we need to move 15 assignments away from a particular pod, only 8 + // assignments will be moved at a time. // // NOTE: This must be between 0 and 1 inclusive. 0 means no rebalancing // will occur. - rebalanceRate = 0.50 + defaultRebalanceRate = 0.50 // defaultMaxConcurrentRebalances represents the maximum number of // concurrent rebalance requests that are being processed. This effectively @@ -78,6 +84,8 @@ type balancerOptions struct { maxConcurrentRebalances int noRebalanceLoop bool timeSource timeutil.TimeSource + rebalanceRate float32 + rebalanceDelay time.Duration } // Option defines an option that can be passed to NewBalancer in order to @@ -109,6 +117,25 @@ func TimeSource(ts timeutil.TimeSource) Option { } } +// RebalanceRate defines the rate of rebalancing across pods. This must be +// between 0 and 1 inclusive. 0 means no rebalancing will occur. +func RebalanceRate(rate float32) Option { + return func(opts *balancerOptions) { + opts.rebalanceRate = rate + } +} + +// RebalanceDelay specifies the minimum amount of time that must elapse between +// attempts to rebalance a given tenant. This delay has the effect of throttling +// RebalanceTenant calls to avoid constantly moving connections around. +// +// RebalanceDelay defaults to defaultRebalanceDelay. Use -1 to never throttle. +func RebalanceDelay(delay time.Duration) Option { + return func(opts *balancerOptions) { + opts.rebalanceDelay = delay + } +} + // Balancer handles load balancing of SQL connections within the proxy. // All methods on the Balancer instance are thread-safe. type Balancer struct { @@ -140,6 +167,22 @@ type Balancer struct { // timeutil.DefaultTimeSource. Override with the TimeSource() option when // calling NewBalancer. timeSource timeutil.TimeSource + + // rebalanceRate represents the rate of rebalancing connections. + rebalanceRate float32 + + // rebalanceDelay is the minimum amount of time that must elapse between + // attempts to rebalance a given tenant. Defaults to defaultRebalanceDelay. + rebalanceDelay time.Duration + + // lastRebalance is the last time each tenant was rebalanced. This is used + // to rate limit the number of rebalances per tenant. Synchronization is + // needed since rebalance operations can be triggered by the rebalance loop, + // or the pod watcher. + lastRebalance struct { + syncutil.Mutex + tenants map[roachpb.TenantID]time.Time + } } // NewBalancer constructs a new Balancer instance that is responsible for @@ -152,16 +195,15 @@ func NewBalancer( opts ...Option, ) (*Balancer, error) { // Handle options. - options := &balancerOptions{} + options := &balancerOptions{ + maxConcurrentRebalances: defaultMaxConcurrentRebalances, + timeSource: timeutil.DefaultTimeSource{}, + rebalanceRate: defaultRebalanceRate, + rebalanceDelay: defaultRebalanceDelay, + } for _, opt := range opts { opt(options) } - if options.maxConcurrentRebalances == 0 { - options.maxConcurrentRebalances = defaultMaxConcurrentRebalances - } - if options.timeSource == nil { - options.timeSource = timeutil.DefaultTimeSource{} - } // Ensure that ctx gets cancelled on stopper's quiescing. ctx, _ = stopper.WithCancelOnQuiesce(ctx) @@ -178,7 +220,11 @@ func NewBalancer( queue: q, processSem: semaphore.New(options.maxConcurrentRebalances), timeSource: options.timeSource, + rebalanceRate: options.rebalanceRate, + rebalanceDelay: options.rebalanceDelay, } + b.lastRebalance.tenants = make(map[roachpb.TenantID]time.Time) + b.connTracker, err = NewConnTracker(ctx, b.stopper, b.timeSource) if err != nil { return nil, err @@ -199,6 +245,46 @@ func NewBalancer( return b, nil } +// RebalanceTenant rebalances connections to the given tenant. If no RUNNING +// pod exists for the given tenant, or the tenant has been recently rebalanced, +// this is a no-op. +func (b *Balancer) RebalanceTenant(ctx context.Context, tenantID roachpb.TenantID) { + // If rebalanced recently, no-op. + if !b.canRebalanceTenant(tenantID) { + return + } + + tenantPods, err := b.directoryCache.TryLookupTenantPods(ctx, tenantID) + if err != nil { + log.Errorf(ctx, "could not rebalance tenant %s: %v", tenantID, err.Error()) + return + } + + // Construct a map so we can easily retrieve the pod by address. + podMap := make(map[string]*tenant.Pod) + var hasRunningPod bool + for _, pod := range tenantPods { + podMap[pod.Addr] = pod + + if pod.State == tenant.RUNNING { + hasRunningPod = true + } + } + + // Only attempt to rebalance if we have a RUNNING pod. In theory, this + // case would only happen if we're scaling down from 1 to 0, in which + // case there is nowhere to transfer connections. Practically, we will + // never scale a tenant from 1 to 0 if there are still active + // connections, so this case should not occur. + if !hasRunningPod { + return + } + + activeList, idleList := b.connTracker.listAssignments(tenantID) + b.rebalancePartition(podMap, activeList) + b.rebalancePartition(podMap, idleList) +} + // SelectTenantPod selects a tenant pod from the given list based on a weighted // CPU load algorithm. It is expected that all pods within the list belongs to // the same tenant. If no pods are available, this returns ErrNoAvailablePods. @@ -299,49 +385,26 @@ func (b *Balancer) rebalanceLoop(ctx context.Context) { } } +// canRebalanceTenant returns true if it has been at least `rebalanceDelay` +// since the last time the given tenant was rebalanced, or false otherwise. 
+func (b *Balancer) canRebalanceTenant(tenantID roachpb.TenantID) bool { + b.lastRebalance.Lock() + defer b.lastRebalance.Unlock() + + now := b.timeSource.Now() + if now.Sub(b.lastRebalance.tenants[tenantID]) < b.rebalanceDelay { + return false + } + b.lastRebalance.tenants[tenantID] = now + return true +} + // rebalance attempts to rebalance connections for all tenants within the proxy. -// -// TODO(jaylim-crl): Update this to support rebalancing a single tenant. That -// way, the pod watcher could call this to rebalance a single tenant. We may -// also want to rate limit the number of rebalances per tenant for requests -// coming from the pod watcher. func (b *Balancer) rebalance(ctx context.Context) { // getTenantIDs ensures that tenants will have at least one connection. tenantIDs := b.connTracker.getTenantIDs() - for _, tenantID := range tenantIDs { - tenantPods, err := b.directoryCache.TryLookupTenantPods(ctx, tenantID) - if err != nil { - // This case shouldn't really occur unless there's a bug in the - // directory server (e.g. deleted pod events, but the pod is still - // alive). - log.Errorf(ctx, "could not lookup pods for tenant %s: %v", tenantID, err.Error()) - continue - } - - // Construct a map so we could easily retrieve the pod by address. - podMap := make(map[string]*tenant.Pod) - var hasRunningPod bool - for _, pod := range tenantPods { - podMap[pod.Addr] = pod - - if pod.State == tenant.RUNNING { - hasRunningPod = true - } - } - - // Only attempt to rebalance if we have a RUNNING pod. In theory, this - // case would happen if we're scaling down from 1 to 0, which in that - // case, we can't transfer connections anywhere. Practically, we will - // never scale a tenant from 1 to 0 if there are still active - // connections, so this case should not occur. - if !hasRunningPod { - continue - } - - activeList, idleList := b.connTracker.listAssignments(tenantID) - b.rebalancePartition(podMap, activeList) - b.rebalancePartition(podMap, idleList) + b.RebalanceTenant(ctx, tenantID) } } @@ -349,8 +412,8 @@ func (b *Balancer) rebalance(ctx context.Context) { func (b *Balancer) rebalancePartition( pods map[string]*tenant.Pod, assignments []*ServerAssignment, ) { - // Nothing to do here. - if len(pods) == 0 || len(assignments) == 0 { + // Nothing to do here if there are no assignments, or only one pod. + if len(pods) <= 1 || len(assignments) == 0 { return } @@ -371,7 +434,7 @@ func (b *Balancer) rebalancePartition( // // NOTE: Elements in the list may be shuffled around once this method returns. 
func (b *Balancer) enqueueRebalanceRequests(list []*ServerAssignment) { - toMoveCount := int(math.Ceil(float64(len(list)) * float64(rebalanceRate))) + toMoveCount := int(math.Ceil(float64(len(list)) * float64(b.rebalanceRate))) partition, _ := partitionNRandom(list, toMoveCount) for _, a := range partition { b.queue.enqueue(&rebalanceRequest{ diff --git a/pkg/ccl/sqlproxyccl/balancer/balancer_test.go b/pkg/ccl/sqlproxyccl/balancer/balancer_test.go index 6395823f60a6..fa59896d3f37 100644 --- a/pkg/ccl/sqlproxyccl/balancer/balancer_test.go +++ b/pkg/ccl/sqlproxyccl/balancer/balancer_test.go @@ -365,6 +365,7 @@ func TestRebalancer_rebalance(t *testing.T) { directoryCache, NoRebalanceLoop(), TimeSource(timeSource), + RebalanceDelay(-1), ) require.NoError(t, err) @@ -715,6 +716,109 @@ func TestRebalancer_rebalance(t *testing.T) { } } +func TestBalancer_RebalanceTenant_WithDefaultDelay(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + // Use a custom time source for testing. + t0 := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC) + timeSource := timeutil.NewManualTime(t0) + + metrics := NewMetrics() + directoryCache := newTestDirectoryCache() + + b, err := NewBalancer( + ctx, + stopper, + metrics, + directoryCache, + NoRebalanceLoop(), + TimeSource(timeSource), + ) + require.NoError(t, err) + + tenantID := roachpb.MakeTenantID(10) + pods := []*tenant.Pod{ + {TenantID: tenantID.ToUint64(), Addr: "127.0.0.30:80", State: tenant.DRAINING}, + {TenantID: tenantID.ToUint64(), Addr: "127.0.0.30:81", State: tenant.RUNNING}, + } + for _, pod := range pods { + require.True(t, directoryCache.upsertPod(pod)) + } + + // Create 100 active connections, all to the draining pod. + const numConns = 100 + var mu syncutil.Mutex + assignments := make([]*ServerAssignment, numConns) + makeTestConnHandle := func(idx int) *testConnHandle { + var handle *testConnHandle + handle = &testConnHandle{ + onTransferConnection: func() error { + mu.Lock() + defer mu.Unlock() + assignments[idx].Close() + assignments[idx] = NewServerAssignment( + tenantID, b.connTracker, handle, pods[1].Addr, + ) + return nil + }, + } + return handle + } + var handles []ConnectionHandle + for i := 0; i < numConns; i++ { + handle := makeTestConnHandle(i) + handles = append(handles, handle) + assignments[i] = NewServerAssignment( + tenantID, b.connTracker, handle, pods[0].Addr, + ) + } + + waitFor := func(numTransfers int) { + testutils.SucceedsSoon(t, func() error { + count := 0 + for i := 0; i < 100; i++ { + count += handles[i].(*testConnHandle).transferConnectionCount() + } + if count != numTransfers { + return errors.Newf("require %d, but got %v", numTransfers, count) + } + return nil + }) + } + + // Attempt the rebalance, and wait until 50 were moved + // (i.e. 100 * defaultRebalanceRate). + b.RebalanceTenant(ctx, tenantID) + waitFor(50) + + // Run the rebalance again. + b.RebalanceTenant(ctx, tenantID) + + // Queue should be empty, and no additional connections should be moved. + b.queue.mu.Lock() + queueLen := b.queue.queue.Len() + b.queue.mu.Unlock() + require.Equal(t, 0, queueLen) + waitFor(50) + + // Advance time, rebalance, and wait until 75 (i.e. 50 + 25) connections + // get moved. + timeSource.Advance(defaultRebalanceDelay) + b.RebalanceTenant(ctx, tenantID) + waitFor(75) + + // Advance time, rebalance, and wait until 88 (i.e. 75 + 13) connections + // get moved. 
+ timeSource.Advance(defaultRebalanceDelay) + b.RebalanceTenant(ctx, tenantID) + waitFor(88) +} + func TestEnqueueRebalanceRequests(t *testing.T) { defer leaktest.AfterTest(t)() diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 9c85db56e4ce..c438da17b3bd 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -105,9 +105,13 @@ type ProxyOptions struct { // proxy. dirOpts []tenant.DirOption - // directoryServer represents the in-memory directory server that is - // created whenever a routing rule is used. + // directoryServer represents the directory server that will be used + // by the proxy handler. If unset, initializing the proxy handler will + // create one, and populate this value. directoryServer tenant.DirectoryServer + + // balancerOpts is used to customize the balancer created by the proxy. + balancerOpts []balancer.Option } } @@ -184,8 +188,31 @@ func newProxyHandler( throttler.WithBaseDelay(handler.ThrottleBaseDelay), ) + // TODO(jaylim-crl): Clean up how we start different types of directory + // servers. We could have two options: remote or local. Local servers + // that use the in-memory implementation should only be used for + // testing. For production use-cases, we will need to update that to + // listen on an actual network address, but there are no plans to + // support that at the moment. var conn *grpc.ClientConn - if handler.DirectoryAddr != "" { + if handler.testingKnobs.directoryServer != nil { + // TODO(jaylim-crl): For now, only support the static version. We should + // make this part of a LocalDirectoryServer interface for us to grab the + // in-memory listener. + directoryServer, ok := handler.testingKnobs.directoryServer.(*tenantdirsvr.TestStaticDirectoryServer) + if !ok { + return nil, errors.New("unsupported test directory server") + } + conn, err = grpc.DialContext( + ctx, + "", + grpc.WithContextDialer(directoryServer.DialerFunc), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, err + } + } else if handler.DirectoryAddr != "" { conn, err = grpc.Dial( handler.DirectoryAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -235,7 +262,11 @@ func newProxyHandler( balancerMetrics := balancer.NewMetrics() registry.AddMetricStruct(balancerMetrics) - handler.balancer, err = balancer.NewBalancer(ctx, stopper, balancerMetrics, handler.directoryCache) + var balancerOpts []balancer.Option + if handler.testingKnobs.balancerOpts != nil { + balancerOpts = append(balancerOpts, handler.testingKnobs.balancerOpts...) + } + handler.balancer, err = balancer.NewBalancer(ctx, stopper, balancerMetrics, handler.directoryCache, balancerOpts...) if err != nil { return nil, err } @@ -406,22 +437,26 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn } // startPodWatcher runs on a background goroutine and listens to pod change -// notifications. When a pod enters the DRAINING state, connections to that pod -// are subject to an idle timeout that closes them after a short period of -// inactivity. If a pod transitions back to the RUNNING state or to the DELETING -// state, then the idle timeout needs to be cleared. -// -// TODO(jaylim-crl): Update comment above. +// notifications. When a pod transitions into the RUNNING state, a rebalance +// operation will be attempted for that particular pod's tenant. 
func (handler *proxyHandler) startPodWatcher(ctx context.Context, podWatcher chan *tenant.Pod) { for { select { case <-ctx.Done(): return - case <-podWatcher: - // TODO(jaylim-crl): Invoke rebalance logic here whenever we see - // a new SQL pod. + case pod := <-podWatcher: + // For better responsiveness, we only care about RUNNING pods. + // DRAINING pods can wait until the next rebalance tick. // - // Do nothing for now. + // Note that there's no easy way to tell whether a SQL pod is new + // since this may race with fetchPodsLocked in tenantEntry, so we + // will just attempt to rebalance whenever we see a RUNNING pod + // event. In most cases, this should only occur when a new SQL pod + // gets added (i.e. stamped), or a DRAINING pod transitions to a + // RUNNING pod. + if pod.State == tenant.RUNNING && pod.TenantID != 0 { + handler.balancer.RebalanceTenant(ctx, roachpb.MakeTenantID(pod.TenantID)) + } } } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 26c0f8967c6a..192d5502ee1a 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -16,6 +16,7 @@ import ( "io/ioutil" "net" "os" + "sort" "strings" "sync" "sync/atomic" @@ -25,6 +26,7 @@ import ( "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/balancer" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenantdirsvr" @@ -763,6 +765,121 @@ func TestDirectoryConnect(t *testing.T) { }) } +func TestPodWatcher(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + defer log.Scope(t).Close(t) + + // Start KV server, and enable session migration. + params, _ := tests.CreateTestServerParams() + s, mainDB, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(ctx) + _, err := mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") + require.NoError(t, err) + + // Start four SQL pods for the test tenant. + var addresses []string + tenantID := serverutils.TestTenantID() + const podCount = 4 + for i := 0; i < podCount; i++ { + params := tests.CreateTestTenantParams(tenantID) + // The first SQL pod will create the tenant keyspace in the host. + if i != 0 { + params.Existing = true + } + tenant, tenantDB := serverutils.StartTenant(t, s, params) + tenant.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant.Stopper().Stop(ctx) + + // Create a test user. We only need to do it once. + if i == 0 { + _, err = tenantDB.Exec("CREATE USER testuser WITH PASSWORD 'hunter2'") + require.NoError(t, err) + _, err = tenantDB.Exec("GRANT admin TO testuser") + require.NoError(t, err) + } + tenantDB.Close() + + addresses = append(addresses, tenant.SQLAddr()) + } + + // Register only 3 SQL pods in the directory server. We will add the 4th + // once the watcher has been established. 
+ tds := tenantdirsvr.NewTestStaticDirectoryServer(s.Stopper(), nil /* timeSource */) + tds.CreateTenant(tenantID, "tenant-cluster") + for i := 0; i < 3; i++ { + tds.AddPod(tenantID, &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: addresses[i], + State: tenant.RUNNING, + StateTimestamp: timeutil.Now(), + }) + } + require.NoError(t, tds.Start(ctx)) + + opts := &ProxyOptions{SkipVerify: true} + opts.testingKnobs.directoryServer = tds + opts.testingKnobs.balancerOpts = []balancer.Option{ + balancer.NoRebalanceLoop(), + balancer.RebalanceRate(1.0), + } + proxy, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + + // Open 12 connections to it. The balancer should distribute the connections + // evenly across 3 SQL pods (i.e. 4 connections each). + dist := map[string]int{} + var conns []*gosql.DB + for i := 0; i < 12; i++ { + db, err := gosql.Open("postgres", connectionString) + require.NoError(t, err) + db.SetMaxOpenConns(1) + defer db.Close() + addr := queryAddr(ctx, t, db) + dist[addr]++ + conns = append(conns, db) + } + + // Validate that connections are balanced evenly (i.e. 12/3 = 4). + for _, d := range dist { + require.Equal(t, 4, d) + } + + // Register the 4th pod. This should emit an event to the pod watcher, which + // triggers rebalancing. Based on the balancer's algorithm, balanced is + // defined as [2, 4] connections. As a result, 2 connections will be moved + // to the new pod. Note that for testing, we set RebalanceRate to 1.0. + tds.AddPod(tenantID, &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: addresses[3], + State: tenant.RUNNING, + StateTimestamp: timeutil.Now(), + }) + + // Wait until two connections have been migrated. + testutils.SucceedsSoon(t, func() error { + if proxy.metrics.ConnMigrationSuccessCount.Count() >= 2 { + return nil + } + return errors.New("waiting for connection migration") + }) + + // Reset distribution and count again. + dist = map[string]int{} + for _, c := range conns { + addr := queryAddr(ctx, t, c) + dist[addr]++ + } + + // Validate distribution. + var counts []int + for _, d := range dist { + counts = append(counts, d) + } + sort.Ints(counts) + require.Equal(t, []int{2, 3, 3, 4}, counts) +} + func TestConnectionMigration(t *testing.T) { defer leaktest.AfterTest(t)() ctx := context.Background() @@ -773,9 +890,7 @@ func TestConnectionMigration(t *testing.T) { defer s.Stopper().Stop(ctx) tenantID := serverutils.TestTenantID() - // TODO(rafi): use ALTER TENANT ALL when available. - _, err := mainDB.Exec(`INSERT INTO system.tenant_settings (tenant_id, name, value, value_type) VALUES - (0, 'server.user_login.session_revival_token.enabled', 'true', 'b')`) + _, err := mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") require.NoError(t, err) // Start first SQL pod. @@ -806,23 +921,6 @@ func TestConnectionMigration(t *testing.T) { connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) - type queryer interface { - QueryRowContext(context.Context, string, ...interface{}) *gosql.Row - } - // queryAddr queries the SQL node that `db` is connected to for its address. 
- queryAddr := func(t *testing.T, ctx context.Context, db queryer) string { - t.Helper() - var host, port string - require.NoError(t, db.QueryRowContext(ctx, ` - SELECT - a.value AS "host", b.value AS "port" - FROM crdb_internal.node_runtime_info a, crdb_internal.node_runtime_info b - WHERE a.component = 'DB' AND a.field = 'Host' - AND b.component = 'DB' AND b.field = 'Port' - `).Scan(&host, &port)) - return fmt.Sprintf("%s:%s", host, port) - } - // validateMiscMetrics ensures that our invariant of // attempts = success + error_recoverable + error_fatal is valid, and all // other transfer related metrics were incremented as well. @@ -887,7 +985,7 @@ func TestConnectionMigration(t *testing.T) { } t.Run("normal_transfer", func(t *testing.T) { - require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + require.Equal(t, tenant1.SQLAddr(), queryAddr(tCtx, t, db)) _, err = db.Exec("SET application_name = 'foo'") require.NoError(t, err) @@ -895,7 +993,7 @@ func TestConnectionMigration(t *testing.T) { // Show that we get alternating SQL pods when we transfer. require.NoError(t, f.TransferConnection()) require.Equal(t, int64(1), f.metrics.ConnMigrationSuccessCount.Count()) - require.Equal(t, tenant2.SQLAddr(), queryAddr(t, tCtx, db)) + require.Equal(t, tenant2.SQLAddr(), queryAddr(tCtx, t, db)) var name string require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) @@ -906,7 +1004,7 @@ func TestConnectionMigration(t *testing.T) { require.NoError(t, f.TransferConnection()) require.Equal(t, int64(2), f.metrics.ConnMigrationSuccessCount.Count()) - require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + require.Equal(t, tenant1.SQLAddr(), queryAddr(tCtx, t, db)) require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) require.Equal(t, "bar", name) @@ -929,7 +1027,7 @@ func TestConnectionMigration(t *testing.T) { // This loop will run approximately 5 seconds. var tenant1Addr, tenant2Addr int for i := 0; i < 100; i++ { - addr := queryAddr(t, tCtx, db) + addr := queryAddr(tCtx, t, db) if addr == tenant1.SQLAddr() { tenant1Addr++ } else { @@ -962,7 +1060,7 @@ func TestConnectionMigration(t *testing.T) { // transfers should not close the connection. t.Run("failed_transfers_with_tx", func(t *testing.T) { initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() - initAddr := queryAddr(t, tCtx, db) + initAddr := queryAddr(tCtx, t, db) err = crdb.ExecuteTx(tCtx, db, nil /* txopts */, func(tx *gosql.Tx) error { // Run multiple times to ensure that connection isn't closed. @@ -974,7 +1072,7 @@ func TestConnectionMigration(t *testing.T) { if !assert.Regexp(t, "cannot serialize", err.Error()) { return errors.Wrap(err, "non-serialization error") } - addr := queryAddr(t, tCtx, tx) + addr := queryAddr(tCtx, t, tx) if initAddr != addr { return errors.Newf( "address does not match, expected %s, found %s", @@ -996,7 +1094,7 @@ func TestConnectionMigration(t *testing.T) { // Once the transaction is closed, transfers should work. 
require.NoError(t, f.TransferConnection()) - require.NotEqual(t, initAddr, queryAddr(t, tCtx, db)) + require.NotEqual(t, initAddr, queryAddr(tCtx, t, db)) require.Nil(t, f.ctx.Err()) require.Equal(t, initSuccessCount+1, f.metrics.ConnMigrationSuccessCount.Count()) require.Equal(t, int64(5), f.metrics.ConnMigrationErrorRecoverableCount.Count()) @@ -1011,7 +1109,7 @@ func TestConnectionMigration(t *testing.T) { t.Run("failed_transfers_with_dial_issues", func(t *testing.T) { initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() initErrorRecoverableCount := f.metrics.ConnMigrationErrorRecoverableCount.Count() - initAddr := queryAddr(t, tCtx, db) + initAddr := queryAddr(tCtx, t, db) // Set the delay longer than the timeout. lookupAddrDelayDuration = 10 * time.Second @@ -1020,7 +1118,7 @@ func TestConnectionMigration(t *testing.T) { err := f.TransferConnection() require.Error(t, err) require.Regexp(t, "injected delays", err.Error()) - require.Equal(t, initAddr, queryAddr(t, tCtx, db)) + require.Equal(t, initAddr, queryAddr(tCtx, t, db)) require.Nil(t, f.ctx.Err()) require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count()) @@ -1618,3 +1716,21 @@ func mustGetTestSimpleDirectoryServer( require.True(t, ok) return svr } + +type queryer interface { + QueryRowContext(context.Context, string, ...interface{}) *gosql.Row +} + +// queryAddr queries the SQL node that `db` is connected to for its address. +func queryAddr(ctx context.Context, t *testing.T, db queryer) string { + t.Helper() + var host, port string + require.NoError(t, db.QueryRowContext(ctx, ` + SELECT + a.value AS "host", b.value AS "port" + FROM crdb_internal.node_runtime_info a, crdb_internal.node_runtime_info b + WHERE a.component = 'DB' AND a.field = 'Host' + AND b.component = 'DB' AND b.field = 'Port' + `).Scan(&host, &port)) + return fmt.Sprintf("%s:%s", host, port) +} diff --git a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel index 7b570b7d2d25..0f3da2561c31 100644 --- a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel @@ -63,6 +63,7 @@ go_test( "//pkg/security/securitytest", "//pkg/server", "//pkg/sql", + "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/skip", "//pkg/testutils/testcluster", @@ -70,9 +71,12 @@ go_test( "//pkg/util/log", "//pkg/util/randutil", "//pkg/util/stop", + "//pkg/util/timeutil", + "@com_github_cockroachdb_errors//:errors", "@com_github_stretchr_testify//require", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//credentials/insecure", "@org_golang_google_grpc//status", ], ) diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_cache.go b/pkg/ccl/sqlproxyccl/tenant/directory_cache.go index 175a69d1c868..095b66c529b4 100644 --- a/pkg/ccl/sqlproxyccl/tenant/directory_cache.go +++ b/pkg/ccl/sqlproxyccl/tenant/directory_cache.go @@ -416,7 +416,12 @@ func (d *directoryCache) watchPods(ctx context.Context, stopper *stop.Stopper) e continue } - // If caller is watching pods, send to its channel now. + // Update the directory entry for the tenant with the latest + // information about this pod. + d.updateTenantEntry(ctx, resp.Pod) + + // If caller is watching pods, send to its channel now. Only do this + // after updating the tenant entry in the directory. 
if d.options.podWatcher != nil { select { case d.options.podWatcher <- resp.Pod: @@ -424,10 +429,6 @@ func (d *directoryCache) watchPods(ctx context.Context, stopper *stop.Stopper) e return } } - - // Update the directory entry for the tenant with the latest - // information about this pod. - d.updateTenantEntry(ctx, resp.Pod) } }) if err != nil { diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go b/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go index 380df7ca2522..45f929c9f9c2 100644 --- a/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go +++ b/pkg/ccl/sqlproxyccl/tenant/directory_cache_test.go @@ -21,14 +21,18 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenantdirsvr" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" ) @@ -67,127 +71,109 @@ func TestDirectoryErrors(t *testing.T) { func TestWatchPods(t *testing.T) { defer leaktest.AfterTest(t)() - defer log.ScopeWithoutShowLogs(t).Close(t) - skip.UnderDeadlockWithIssue(t, 71365) + defer log.Scope(t).Close(t) + ctx := context.Background() // Make pod watcher channel. podWatcher := make(chan *tenant.Pod, 1) - // Create the directory. - ctx := context.Background() - tc, dir, tds := newTestDirectoryCache(t, tenant.PodWatcher(podWatcher)) - defer tc.Stopper().Stop(ctx) + // Setup test directory cache and server. + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + dir, tds := setupTestDirectory(t, ctx, stopper, nil /* timeSource */, tenant.PodWatcher(podWatcher)) + + // Wait until the watcher has been established. + testutils.SucceedsSoon(t, func() error { + if tds.WatchListenersCount() == 0 { + return errors.New("watchers have not been established yet") + } + return nil + }) tenantID := roachpb.MakeTenantID(20) - require.NoError(t, createTenant(tc, tenantID)) + tds.CreateTenant(tenantID, "my-tenant") + + // Add a new pod to the tenant. + runningPod := &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: "127.0.0.10:10", + State: tenant.RUNNING, + StateTimestamp: timeutil.Now(), + } + require.True(t, tds.AddPod(tenantID, runningPod)) + pod := <-podWatcher + require.Equal(t, runningPod, pod) - // Call LookupTenantPods to start a new tenant and create an entry. - pods, err := dir.LookupTenantPods(ctx, tenantID, "") + // Directory cache should have already been updated. + pods, err := dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - require.NotEmpty(t, pods) - addr := pods[0].Addr - - // Ensure that correct event was sent to watcher channel. - pod := <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.RUNNING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) + require.Len(t, pods, 1) + require.Equal(t, runningPod, pods[0]) - // Trigger drain of pod. - tds.Drain() + // Drain the pod. 
+ require.True(t, tds.DrainPod(tenantID, runningPod.Addr)) pod = <-podWatcher require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) + require.Equal(t, runningPod.Addr, pod.Addr) require.Equal(t, tenant.DRAINING, pod.State) require.False(t, pod.StateTimestamp.IsZero()) - // Now shut the tenant directory down. - processes := tds.Get(tenantID) - require.NotNil(t, processes) - require.Len(t, processes, 1) - // Stop the tenant and ensure its IP address is removed from the directory. - for _, process := range processes { - process.Stopper.Stop(ctx) - } - - // Ensure that correct event was sent to watcher channel. - pod = <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.DELETING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) - - // Ensure that all addresses have been cleared from the directory, since - // it should only return RUNNING or DRAINING addresses. - require.Eventually(t, func() bool { - tenantPods, _ := dir.TryLookupTenantPods(ctx, tenantID) - return len(tenantPods) == 0 - }, 10*time.Second, 100*time.Millisecond) - - // Resume tenant again by a direct call to the directory server - _, err = tds.EnsurePod(ctx, &tenant.EnsurePodRequest{tenantID.ToUint64()}) + // Directory cache should be updated with the DRAINING pod. + pods, err = dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - - // Wait for background watcher to populate the initial pod. - require.Eventually(t, func() bool { - tenantPods, _ := dir.TryLookupTenantPods(ctx, tenantID) - if len(tenantPods) != 0 { - addr = tenantPods[0].Addr - return true + require.Len(t, pods, 1) + require.Equal(t, pod, pods[0]) + + // Trigger the directory server to restart. WatchPods should handle + // reconnection properly. + // + // NOTE: We check for the number of listeners before proceeding with the + // pod update (e.g. AddPod) because if we don't do that, there could be a + // situation where AddPod gets called before the watcher gets established, + // which means that the pod update event will never get emitted. This is + // only a test directory server issue due to its simple implementation. One + // way to solve this nicely is to implement checkpointing based on the sent + // updates (just like how Kubernetes bookmarks work). + tds.Stop(ctx) + testutils.SucceedsSoon(t, func() error { + if tds.WatchListenersCount() != 0 { + return errors.New("watchers have not been removed yet") } - return false - }, 10*time.Second, 100*time.Millisecond) + return nil + }) + require.NoError(t, tds.Start(ctx)) + testutils.SucceedsSoon(t, func() error { + if tds.WatchListenersCount() == 0 { + return errors.New("watchers have not been established yet") + } + return nil + }) - // Ensure that correct event was sent to watcher channel. + // Put the same pod back to running. + require.True(t, tds.AddPod(tenantID, runningPod)) pod = <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.RUNNING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) + require.Equal(t, runningPod, pod) - // Verify that LookupTenantPods returns the pod's IP address. - pods, err = dir.LookupTenantPods(ctx, tenantID, "") + // Directory cache should be updated with the RUNNING pod. 
+ pods, err = dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - require.NotEmpty(t, pods) - addr = pods[0].Addr - - processes = tds.Get(tenantID) - require.NotNil(t, processes) - require.Len(t, processes, 1) - for dirAddr := range processes { - require.Equal(t, addr, dirAddr.String()) - } - - // Stop the tenant and ensure its IP address is removed from the directory. - for _, process := range processes { - process.Stopper.Stop(ctx) - } - - require.Eventually(t, func() bool { - tenantPods, _ := dir.TryLookupTenantPods(ctx, tenantID) - return len(tenantPods) == 0 - }, 10*time.Second, 100*time.Millisecond) + require.Len(t, pods, 1) + require.Equal(t, pod, pods[0]) - // Ensure that correct event was sent to watcher channel. + // Delete the pod. + require.True(t, tds.RemovePod(tenantID, runningPod.Addr)) pod = <-podWatcher require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) + require.Equal(t, runningPod.Addr, pod.Addr) require.Equal(t, tenant.DELETING, pod.State) require.False(t, pod.StateTimestamp.IsZero()) - // Verify that a new call to LookupTenantPods will resume again the tenant. - pods, err = dir.LookupTenantPods(ctx, tenantID, "") + // Directory cache should have no pods. + pods, err = dir.TryLookupTenantPods(ctx, tenantID) require.NoError(t, err) - require.NotEmpty(t, pods) - addr = pods[0].Addr - - // Ensure that correct event was sent to watcher channel. - pod = <-podWatcher - require.Equal(t, tenantID.ToUint64(), pod.TenantID) - require.Equal(t, addr, pod.Addr) - require.Equal(t, tenant.RUNNING, pod.State) - require.False(t, pod.StateTimestamp.IsZero()) + require.Empty(t, pods) + stopper.Stop(ctx) } func TestCancelLookups(t *testing.T) { @@ -434,6 +420,40 @@ func startTenant( return &tenantdirsvr.Process{SQL: sqlAddr, Stopper: tenantStopper}, nil } +// setupTestDirectory returns an instance of the directory cache and the +// in-memory test static directory server. Tenants will need to be added/removed +// manually. +func setupTestDirectory( + t *testing.T, + ctx context.Context, + stopper *stop.Stopper, + timeSource timeutil.TimeSource, + opts ...tenant.DirOption, +) (tenant.DirectoryCache, *tenantdirsvr.TestStaticDirectoryServer) { + t.Helper() + + // Start an in-memory static directory server. + directoryServer := tenantdirsvr.NewTestStaticDirectoryServer(stopper, timeSource) + require.NoError(t, directoryServer.Start(ctx)) + + // Dial the test directory server. + conn, err := grpc.DialContext( + ctx, + "", + grpc.WithContextDialer(directoryServer.DialerFunc), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + stopper.AddCloser(stop.CloserFn(func() { + _ = conn.Close() // nolint:grpcconnclose + })) + client := tenant.NewDirectoryClient(conn) + directoryCache, err := tenant.NewDirectoryCache(ctx, stopper, client, opts...) + require.NoError(t, err) + + return directoryCache, directoryServer +} + // Setup directory cache that uses a client connected to a test directory server // that manages tenants connected to a backing KV server. 
func newTestDirectoryCache( @@ -464,7 +484,7 @@ func newTestDirectoryCache( return process, nil } - listenPort, err := net.Listen("tcp", ":0") + listenPort, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = tds.Serve(listenPort) }() diff --git a/pkg/ccl/sqlproxyccl/tenant/entry.go b/pkg/ccl/sqlproxyccl/tenant/entry.go index e68cc0ff8c32..b3e99d246bf7 100644 --- a/pkg/ccl/sqlproxyccl/tenant/entry.go +++ b/pkg/ccl/sqlproxyccl/tenant/entry.go @@ -239,6 +239,7 @@ func (e *tenantEntry) fetchPodsLocked( ctx context.Context, client DirectoryClient, ) (tenantPods []*Pod, err error) { // List the pods for the given tenant. + // // TODO(andyk): This races with the pod watcher, which may receive updates // that are newer than what ListPods returns. This could be fixed by adding // version values to the pods in order to detect races. diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel index c577e2e78f61..f56109730cd3 100644 --- a/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "in_mem_listener.go", "test_directory_svr.go", "test_simple_directory_svr.go", + "test_static_directory_svr.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenantdirsvr", visibility = ["//visibility:public"], diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go index cd8c780b2dfa..3f387ec8ec98 100644 --- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go @@ -48,7 +48,7 @@ type Process struct { // is a possibility that between the two calls, the parent stopper completes a // stop and then the leak detection may find a leaked stopper. func NewSubStopper(parentStopper *stop.Stopper) *stop.Stopper { - mu := &syncutil.Mutex{} + var mu syncutil.Mutex var subStopper *stop.Stopper parentStopper.AddCloser(stop.CloserFn(func() { mu.Lock() diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go index 6961a84407cf..9f0d5565c6aa 100644 --- a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go @@ -44,7 +44,7 @@ type TestSimpleDirectoryServer struct { var _ tenant.DirectoryServer = &TestSimpleDirectoryServer{} // NewTestSimpleDirectoryServer constructs a new simple directory server. -func NewTestSimpleDirectoryServer(podAddr string) (tenant.DirectoryServer, *grpc.Server) { +func NewTestSimpleDirectoryServer(podAddr string) (*TestSimpleDirectoryServer, *grpc.Server) { dir := &TestSimpleDirectoryServer{podAddr: podAddr} dir.mu.deleted = make(map[roachpb.TenantID]struct{}) grpcServer := grpc.NewServer() diff --git a/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go new file mode 100644 index 000000000000..ed40d2113435 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go @@ -0,0 +1,432 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package tenantdirsvr + +import ( + "container/list" + "context" + "net" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + "github.com/gogo/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/test/bufconn" +) + +// TestStaticDirectoryServer is a directory server that stores a static mapping +// of tenants to their pods. Callers will need to invoke operations on the +// directory server to create/delete/update tenants and pods as necessary. +// Unlike the regular directory server that automatically spins up SQL pods when +// one isn't present, the static directory server does not do that. The caller +// should start the SQL pods up, and register them with the directory server. +type TestStaticDirectoryServer struct { + // rootStopper is used to create sub-stoppers for directory instances. + rootStopper *stop.Stopper + + // timeSource is the source of time. By default, this will be set to + // timeutil.DefaultTimeSource. + timeSource timeutil.TimeSource + + // process corresponds to fields that need to be updated together whenever + // the test directory is started or stopped. Note that when the test + // directory server gets restarted, we do not reset the tenant data. + process struct { + syncutil.Mutex + + // stopper is used to start async tasks within the directory server. + stopper *stop.Stopper + + // ln corresponds to the in-memory listener for the directory + // server. This will be used to construct a dialer function. + ln *bufconn.Listener + + // grpcServer corresponds to the GRPC server instance for the + // directory server. + grpcServer *grpc.Server + } + + mu struct { + syncutil.Mutex + + // tenants stores a list of pods for every tenant. If a tenant does not + // exist in the map, the tenant is assumed to be non-existent. Since + // this is a test directory server, we do not care about fine-grained + // locking here. + tenants map[roachpb.TenantID][]*tenant.Pod + + // tenantNames stores the cluster name associated with each tenant. + tenantNames map[roachpb.TenantID]string + + // eventListeners stores a list of listeners that are watching changes + // to pods through WatchPods. + eventListeners *list.List + } +} + +var _ tenant.DirectoryServer = &TestStaticDirectoryServer{} + +// NewTestStaticDirectoryServer constructs a new static directory server. +func NewTestStaticDirectoryServer( + stopper *stop.Stopper, timeSource timeutil.TimeSource, +) *TestStaticDirectoryServer { + if timeSource == nil { + timeSource = timeutil.DefaultTimeSource{} + } + dir := &TestStaticDirectoryServer{rootStopper: stopper, timeSource: timeSource} + dir.mu.tenants = make(map[roachpb.TenantID][]*tenant.Pod) + dir.mu.tenantNames = make(map[roachpb.TenantID]string) + dir.mu.eventListeners = list.New() + return dir +} + +// ListPods returns a list with all SQL pods associated with the given tenant. +// If the tenant does not exist, no pods will be returned. +// +// ListPods implements the tenant.DirectoryServer interface. 
+func (d *TestStaticDirectoryServer) ListPods( + ctx context.Context, req *tenant.ListPodsRequest, +) (*tenant.ListPodsResponse, error) { + d.mu.Lock() + defer d.mu.Unlock() + + pods, ok := d.mu.tenants[roachpb.MakeTenantID(req.TenantID)] + if !ok { + return &tenant.ListPodsResponse{}, nil + } + + // Return a copy of all the pods to avoid any race issues. + tenantPods := make([]*tenant.Pod, len(pods)) + for i, pod := range pods { + copyPod := *pod + tenantPods[i] = &copyPod + } + return &tenant.ListPodsResponse{Pods: tenantPods}, nil +} + +// WatchPods allows callers to monitor for pod update events. +// +// WatchPods implements the tenant.DirectoryServer interface. +func (d *TestStaticDirectoryServer) WatchPods( + req *tenant.WatchPodsRequest, server tenant.Directory_WatchPodsServer, +) error { + d.process.Lock() + stopper := d.process.stopper + d.process.Unlock() + + // This cannot happen unless WatchPods was called directly, which we + // shouldn't do since it is meant to be called through a GRPC client. + if stopper == nil { + return status.Errorf(codes.FailedPrecondition, "directory server has not been started") + } + + addListener := func(ch chan *tenant.WatchPodsResponse) *list.Element { + d.mu.Lock() + defer d.mu.Unlock() + return d.mu.eventListeners.PushBack(ch) + } + removeListener := func(e *list.Element) chan *tenant.WatchPodsResponse { + d.mu.Lock() + defer d.mu.Unlock() + // list.Remove returns e.Value even if the element was already removed + // by notifyPodUpdateLocked, so check membership first to avoid closing + // the channel twice. + for le := d.mu.eventListeners.Front(); le != nil; le = le.Next() { + if le == e { + return d.mu.eventListeners.Remove(e).(chan *tenant.WatchPodsResponse) + } + } + return nil + } + + // Construct the channel with a small buffer to allow for a burst of + // notifications, and a slow receiver. + c := make(chan *tenant.WatchPodsResponse, 10) + chElement := addListener(c) + + return stopper.RunTask( + context.Background(), + "watch-pods-server", + func(ctx context.Context) { + defer func() { + if ch := removeListener(chElement); ch != nil { + close(ch) + } + }() + + for watch := true; watch; { + select { + case e, ok := <-c: + // Channel was closed. + if !ok { + watch = false + break + } + if err := server.Send(e); err != nil { + watch = false + } + case <-stopper.ShouldQuiesce(): + watch = false + } + } + }, + ) +} + +// EnsurePod returns an empty response if a tenant with the given tenant ID +// exists, and there is at least one SQL pod. If there are no SQL pods, a GRPC +// FailedPrecondition error will be returned. Similarly, if the tenant does not +// exist, a GRPC NotFound error will be returned. This mimics the behavior of +// the actual tenant directory. +// +// EnsurePod implements the tenant.DirectoryServer interface. +func (d *TestStaticDirectoryServer) EnsurePod( + ctx context.Context, req *tenant.EnsurePodRequest, +) (*tenant.EnsurePodResponse, error) { + d.mu.Lock() + defer d.mu.Unlock() + + pods, ok := d.mu.tenants[roachpb.MakeTenantID(req.TenantID)] + if !ok { + return nil, status.Errorf(codes.NotFound, "tenant does not exist") + } + if len(pods) == 0 { + return nil, status.Errorf(codes.FailedPrecondition, "tenant has no pods") + } + return &tenant.EnsurePodResponse{}, nil +} + +// GetTenant returns tenant metadata associated with the given tenant ID. If the +// tenant isn't in the directory server, a GRPC NotFound error will be returned. +// +// GetTenant implements the tenant.DirectoryServer interface. 
+func (d *TestStaticDirectoryServer) GetTenant( + ctx context.Context, req *tenant.GetTenantRequest, +) (*tenant.GetTenantResponse, error) { + d.mu.Lock() + defer d.mu.Unlock() + + tenantID := roachpb.MakeTenantID(req.TenantID) + if _, ok := d.mu.tenants[tenantID]; !ok { + return nil, status.Errorf(codes.NotFound, "tenant does not exist") + } + return &tenant.GetTenantResponse{ClusterName: d.mu.tenantNames[tenantID]}, nil +} + +// CreateTenant creates a tenant with the given tenant ID in the directory +// server. If the tenant already exists, this is a no-op. +func (d *TestStaticDirectoryServer) CreateTenant(tenantID roachpb.TenantID, clusterName string) { + d.mu.Lock() + defer d.mu.Unlock() + + if _, ok := d.mu.tenants[tenantID]; ok { + return + } + d.mu.tenants[tenantID] = make([]*tenant.Pod, 0) + d.mu.tenantNames[tenantID] = clusterName +} + +// DeleteTenant ensures that the tenant with the given tenant ID has been +// removed from the directory server. Doing this would return a NotFound error +// for certain directory server endpoints. This also changes the behavior of +// ListPods so no pods are returned for the given tenant. +func (d *TestStaticDirectoryServer) DeleteTenant(tenantID roachpb.TenantID) { + d.mu.Lock() + defer d.mu.Unlock() + + pods, ok := d.mu.tenants[tenantID] + if !ok { + return + } + for _, pod := range pods { + // Update pod to DELETING, and emit event. + pod.State = tenant.DELETING + pod.StateTimestamp = d.timeSource.Now() + d.notifyPodUpdateLocked(pod) + } + delete(d.mu.tenants, tenantID) + delete(d.mu.tenantNames, tenantID) +} + +// AddPod adds the given pod to the tenant's list of pods. If a SQL pod with +// the same address already exists, this updates the existing pod. AddPod +// returns true if the operation was successful, and false otherwise. +// +// NOTE: pod has to be fully populated, and pod.TenantID should be the same as +// the given tenantID. +func (d *TestStaticDirectoryServer) AddPod(tenantID roachpb.TenantID, pod *tenant.Pod) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // Tenant does not exist. + pods, ok := d.mu.tenants[tenantID] + if !ok { + return false + } + + // Emit an event that the pod has been created. + d.notifyPodUpdateLocked(pod) + + // Check if the pod exists. This would handle pods transitioning from + // DRAINING to RUNNING. + for i, existing := range pods { + if existing.Addr == pod.Addr { + d.mu.tenants[tenantID][i] = pod + return true + } + } + + // A new pod has been added. + d.mu.tenants[tenantID] = append(d.mu.tenants[tenantID], pod) + return true +} + +// DrainPod puts the tenant's SQL pod with the given address into the +// DRAINING state. DrainPod returns true if the operation was successful, +// and false otherwise. +func (d *TestStaticDirectoryServer) DrainPod(tenantID roachpb.TenantID, podAddr string) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // Tenant does not exist. + pods, ok := d.mu.tenants[tenantID] + if !ok { + return false + } + + // If the pod exists, update its state to DRAINING. + for _, existing := range pods { + if existing.Addr == podAddr { + existing.State = tenant.DRAINING + existing.StateTimestamp = d.timeSource.Now() + d.notifyPodUpdateLocked(existing) + return true + } + } + return false +} + +// RemovePod deletes the SQL pod with the given address from the associated +// tenant. RemovePod returns true if the operation was successful, and false +// otherwise. 
+func (d *TestStaticDirectoryServer) RemovePod(tenantID roachpb.TenantID, podAddr string) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // Tenant does not exist. + pods, ok := d.mu.tenants[tenantID] + if !ok { + return false + } + + // If the pod exists, remove it. + for i, existing := range pods { + if existing.Addr == podAddr { + // Remove pod. + copy(d.mu.tenants[tenantID][i:], d.mu.tenants[tenantID][i+1:]) + d.mu.tenants[tenantID] = d.mu.tenants[tenantID][:len(d.mu.tenants[tenantID])-1] + + // Update pod to DELETING, and emit event. + existing.State = tenant.DELETING + existing.StateTimestamp = d.timeSource.Now() + d.notifyPodUpdateLocked(existing) + return true + } + } + return false +} + +// Start starts the test directory server using an in-memory listener. This +// returns an error if the server cannot be started. If the server has already +// been started, this is a no-op. +func (d *TestStaticDirectoryServer) Start(ctx context.Context) error { + d.process.Lock() + defer d.process.Unlock() + + // Server has already been started. + if d.process.ln != nil { + return nil + } + + stopper := NewSubStopper(d.rootStopper) + grpcServer := grpc.NewServer() + tenant.RegisterDirectoryServer(grpcServer, d) + ln, err := ListenAndServeInMemGRPC(ctx, stopper, grpcServer) + if err != nil { + return err + } + + // Update instance fields together. + d.process.stopper = stopper + d.process.ln = ln + d.process.grpcServer = grpcServer + return nil +} + +// Stop stops the test directory server instance. If the server has already +// been stopped, this is a no-op. +func (d *TestStaticDirectoryServer) Stop(ctx context.Context) { + d.process.Lock() + defer d.process.Unlock() + + stopper, ln, grpcServer := d.process.stopper, d.process.ln, d.process.grpcServer + if ln == nil { + return + } + + d.process.stopper = nil + d.process.ln = nil + d.process.grpcServer = nil + + // Close listener first, followed by GRPC server, and stopper. + _ = ln.Close() + grpcServer.Stop() + stopper.Stop(ctx) +} + +// DialerFunc corresponds to the dialer function used to dial the test directory +// server. We do this because the test directory server runs in-memory, and +// does not bind to a physical network address. This will be used within +// grpc.WithContextDialer. +func (d *TestStaticDirectoryServer) DialerFunc(ctx context.Context, addr string) (net.Conn, error) { + d.process.Lock() + listener := d.process.ln + d.process.Unlock() + + if listener == nil { + return nil, errors.New("directory server has not been started") + } + return listener.DialContext(ctx) +} + +// WatchListenersCount returns the number of active listeners for pod update +// events. +func (d *TestStaticDirectoryServer) WatchListenersCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.mu.eventListeners.Len() +} + +// notifyPodUpdateLocked sends a pod update event to all WatchPods listeners. +func (d *TestStaticDirectoryServer) notifyPodUpdateLocked(pod *tenant.Pod) { + // Make a copy of the pod to prevent race issues. + copyPod := *pod + res := &tenant.WatchPodsResponse{Pod: &copyPod} + + for e := d.mu.eventListeners.Front(); e != nil; { + select { + case e.Value.(chan *tenant.WatchPodsResponse) <- res: + e = e.Next() + default: + // The receiver is unable to consume fast enough. Close the channel + // and remove it from the list. 
+ eToClose := e + e = e.Next() + ch := d.mu.eventListeners.Remove(eToClose) + close(ch.(chan *tenant.WatchPodsResponse)) + } + } +} diff --git a/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go b/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go index 7946c7e96d48..a46962bf1e0f 100644 --- a/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go +++ b/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go @@ -36,9 +36,7 @@ func TestShowTransferState(t *testing.T) { _, err := tenantDB.Exec("CREATE USER testuser WITH PASSWORD 'hunter2'") require.NoError(t, err) - // TODO(rafi): use ALTER TENANT ALL when available. - _, err = mainDB.Exec(`INSERT INTO system.tenant_settings (tenant_id, name, value, value_type) VALUES - (0, 'server.user_login.session_revival_token.enabled', 'true', 'b')`) + _, err = mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") require.NoError(t, err) t.Run("without_transfer_key", func(t *testing.T) { diff --git a/pkg/testutils/lint/lint_test.go b/pkg/testutils/lint/lint_test.go index 53e654b99d6f..38b7fbea166e 100644 --- a/pkg/testutils/lint/lint_test.go +++ b/pkg/testutils/lint/lint_test.go @@ -866,6 +866,7 @@ func TestLint(t *testing.T) { ":!util/tracing/*_test.go", ":!ccl/sqlproxyccl/tenantdirsvr/test_directory_svr.go", ":!ccl/sqlproxyccl/tenantdirsvr/test_simple_directory_svr.go", + ":!ccl/sqlproxyccl/tenantdirsvr/test_static_directory_svr.go", ) if err != nil { t.Fatal(err)