Skip to content

Commit

Permalink
feat(cluster): add SessionConfigOption
Browse files Browse the repository at this point in the history
Unfortunately, there is no way in gocql to specify the host to which given query should be routed. This functionality can be achieved by using the SingleHostSessionConfigOption. It will be needed when executing raft read barrier after snapshot and before desc schema stages in backup.
  • Loading branch information
Michal-Leszczynski authored and karol-kokoszka committed Jun 21, 2024
1 parent 4cf127e commit 6f8d6fc
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 32 deletions.
3 changes: 2 additions & 1 deletion pkg/service/backup/service_backup_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/scylladb/go-set/strset"
"github.com/scylladb/gocqlx/v2"
"github.com/scylladb/gocqlx/v2/qb"
"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
"go.uber.org/atomic"
"go.uber.org/zap/zapcore"

Expand Down Expand Up @@ -121,7 +122,7 @@ func newTestServiceWithUser(t *testing.T, session gocqlx.Session, client *scylla
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateManagedClusterSession(t, false, client, user, pass), nil
},
logger.Named("backup"),
Expand Down
84 changes: 61 additions & 23 deletions pkg/service/cluster/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,83 +583,121 @@ func (s *Service) ListNodes(ctx context.Context, clusterID uuid.UUID) ([]Node, e
return nodes, nil
}

// SessionConfigOption defines function modifying cluster config that can be used when creating session.
type SessionConfigOption func(ctx context.Context, clusterID uuid.UUID, client *scyllaclient.Client, cfg *gocql.ClusterConfig) error

// SingleHostSessionConfigOption ensures that session will be connected only to the single, provided host.
func SingleHostSessionConfigOption(host string) SessionConfigOption {
return func(ctx context.Context, clusterID uuid.UUID, client *scyllaclient.Client, cfg *gocql.ClusterConfig) error {
ni, err := client.NodeInfo(ctx, host)
if err != nil {
return errors.Wrapf(err, "fetch node (%s) info", host)
}
cqlAddr := ni.CQLAddr(host)
cfg.Hosts = []string{cqlAddr}
cfg.HostFilter = gocql.WhiteListHostFilter(cqlAddr)
cfg.DisableInitialHostLookup = true
return nil
}
}

// SessionFunc returns CQL session for given cluster ID.
type SessionFunc func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error)
type SessionFunc func(ctx context.Context, clusterID uuid.UUID, opts ...SessionConfigOption) (gocqlx.Session, error)

// GetSession returns CQL session to provided cluster.
func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID) (session gocqlx.Session, err error) {
s.logger.Debug(ctx, "GetSession", "cluster_id", clusterID)
func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID, opts ...SessionConfigOption) (session gocqlx.Session, err error) {
s.logger.Info(ctx, "Get session", "cluster_id", clusterID)

client, err := s.CreateClientNoCache(ctx, clusterID)
if err != nil {
return session, errors.Wrap(err, "get client")
}
defer logutil.LogOnError(ctx, s.logger, client.Close, "Couldn't close scylla client")

cfg := gocql.NewCluster()
for _, opt := range opts {
if err := opt(ctx, clusterID, client, cfg); err != nil {
return session, err
}
}
// Fill hosts if they weren't specified by the options
if len(cfg.Hosts) == 0 {
sessionHosts, err := GetRPCAddresses(ctx, client, client.Config().Hosts)
if err != nil {
s.logger.Info(ctx, "Gets session", "err", err)
if errors.Is(err, ErrNoRPCAddressesFound) {
return session, err
}
}
cfg.Hosts = sessionHosts
}

ni, err := client.AnyNodeInfo(ctx)
if err != nil {
return session, errors.Wrap(err, "fetch node info")
}

sessionHosts, err := GetRPCAddresses(ctx, client, client.Config().Hosts)
if err != nil {
s.logger.Info(ctx, "GetSession", "err", err)
if errors.Is(err, ErrNoRPCAddressesFound) {
return session, err
}
if err := s.extendClusterConfigWithAuthentication(clusterID, ni, cfg); err != nil {
return session, err
}
if err := s.extendClusterConfigWithTLS(ctx, clusterID, ni, cfg); err != nil {
return session, err
}

scyllaCluster := gocql.NewCluster(sessionHosts...)
cqlPort := ni.CQLPort()
return gocqlx.WrapSession(cfg.CreateSession())
}

func (s *Service) extendClusterConfigWithAuthentication(clusterID uuid.UUID, ni *scyllaclient.NodeInfo, cfg *gocql.ClusterConfig) error {
if ni.CqlPasswordProtected {
credentials := secrets.CQLCreds{
ClusterID: clusterID,
}
err := s.secretsStore.Get(&credentials)
if errors.Is(err, service.ErrNotFound) {
return session, errors.New("cluster requires CQL authentication but username/password was not set")
return errors.New("cluster requires CQL authentication but username/password was not set")
}
if err != nil {
return session, errors.Wrap(err, "get credentials")
return errors.Wrap(err, "get credentials")
}

scyllaCluster.Authenticator = gocql.PasswordAuthenticator{
cfg.Authenticator = gocql.PasswordAuthenticator{
Username: credentials.Username,
Password: credentials.Password,
}
}
return nil
}

func (s *Service) extendClusterConfigWithTLS(ctx context.Context, clusterID uuid.UUID, ni *scyllaclient.NodeInfo, cfg *gocql.ClusterConfig) error {
cluster, err := s.GetClusterByID(ctx, clusterID)
if err != nil {
return session, errors.Wrap(err, "get cluster by id")
return errors.Wrap(err, "get cluster by id")
}

cqlPort := ni.CQLPort()
if ni.ClientEncryptionEnabled && !cluster.ForceTLSDisabled {
if !cluster.ForceNonSSLSessionPort {
cqlPort = ni.CQLSSLPort()
}
scyllaCluster.SslOpts = &gocql.SslOptions{
cfg.SslOpts = &gocql.SslOptions{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
if ni.ClientEncryptionRequireAuth {
keyPair, err := s.loadTLSIdentity(clusterID)
if err != nil {
return session, err
return err
}
scyllaCluster.SslOpts.Config.Certificates = []tls.Certificate{keyPair}
cfg.SslOpts.Config.Certificates = []tls.Certificate{keyPair}
}
}

p, err := strconv.Atoi(cqlPort)
if err != nil {
return session, errors.Wrap(err, "parse cql port")
return errors.Wrap(err, "parse cql port")
}
scyllaCluster.Port = p

return gocqlx.WrapSession(scyllaCluster.CreateSession())
cfg.Port = p
return nil
}

func (s *Service) loadTLSIdentity(clusterID uuid.UUID) (tls.Certificate, error) {
Expand Down
5 changes: 3 additions & 2 deletions pkg/service/repair/service_repair_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/scylladb/scylla-manager/v3/pkg/dht"
"github.com/scylladb/scylla-manager/v3/pkg/ping/cqlping"
"github.com/scylladb/scylla-manager/v3/pkg/schema/table"
"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
"github.com/scylladb/scylla-manager/v3/pkg/service/scheduler"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils/testhelper"
Expand Down Expand Up @@ -397,7 +398,7 @@ func newTestService(t *testing.T, session gocqlx.Session, client *scyllaclient.C
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return gocqlx.Session{}, errors.New("not implemented")
},
logger.Named("repair"),
Expand All @@ -419,7 +420,7 @@ func newTestServiceWithClusterSession(t *testing.T, session gocqlx.Session, clie
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateSession(t, client), nil
},
logger.Named("repair"),
Expand Down
7 changes: 4 additions & 3 deletions pkg/service/restore/helper_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/scylladb/go-log"
"github.com/scylladb/gocqlx/v2"
"github.com/scylladb/gocqlx/v2/qb"
"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
"github.com/scylladb/scylla-manager/v3/pkg/util/version"
"go.uber.org/zap/zapcore"

Expand Down Expand Up @@ -121,7 +122,7 @@ func newBackupSvc(t *testing.T, mgrSession gocqlx.Session, client *scyllaclient.
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateSession(t, client), nil
},
log.NewDevelopmentWithLevel(zapcore.ErrorLevel).Named("backup"),
Expand All @@ -140,7 +141,7 @@ func newRestoreSvc(t *testing.T, mgrSession gocqlx.Session, client *scyllaclient
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateSession(t, client), nil
},
log.NewDevelopmentWithLevel(zapcore.ErrorLevel).Named("repair"),
Expand All @@ -157,7 +158,7 @@ func newRestoreSvc(t *testing.T, mgrSession gocqlx.Session, client *scyllaclient
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateManagedClusterSession(t, false, client, user, pass), nil
},
log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("restore"),
Expand Down
6 changes: 3 additions & 3 deletions pkg/service/restore/service_restore_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func newTestService(t *testing.T, session gocqlx.Session, client *scyllaclient.C
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateManagedClusterSession(t, false, client, user, pass), nil
},
log.NewDevelopmentWithLevel(zapcore.ErrorLevel).Named("repair"),
Expand All @@ -134,7 +134,7 @@ func newTestService(t *testing.T, session gocqlx.Session, client *scyllaclient.C
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateManagedClusterSession(t, false, client, user, pass), nil
},
log.NewDevelopmentWithLevel(zapcore.ErrorLevel).Named("backup"),
Expand All @@ -151,7 +151,7 @@ func newTestService(t *testing.T, session gocqlx.Session, client *scyllaclient.C
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
func(ctx context.Context, clusterID uuid.UUID, _ ...cluster.SessionConfigOption) (gocqlx.Session, error) {
return CreateManagedClusterSession(t, false, client, user, pass), nil
},
logger.Named("restore"),
Expand Down

0 comments on commit 6f8d6fc

Please sign in to comment.