Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: Get rid of globals for setting a protocol listener factory. Get rid of unused, global-ridden and complicated Interceptor and Option functionality. #2879

Merged
merged 2 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions server/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,36 @@ import (
sqle "github.com/dolthub/go-mysql-server"
)

func Intercept(h Interceptor) {
inters = append(inters, h)
sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() })
}

func WithChain() Option {
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) {
f := DefaultProtocolListenerFunc
DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
cfg.Handler = buildChain(cfg.Handler)
return f(cfg, sel)
}
}
// InterceptorChain allows an integrator to build a chain of
// |Interceptor| instances which will wrap and intercept the server's
// mysql.Handler.
//
// Example usage:
//
// var ic InterceptorChain
// ic.WithInterceptor(metricsInterceptor)
// ic.WithInterceptor(authInterceptor)
// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...)
type InterceptorChain struct {
inters []Interceptor
}

var inters []Interceptor
func (ic *InterceptorChain) WithInterceptor(h Interceptor) {
ic.inters = append(ic.inters, h)
}

func buildChain(h mysql.Handler) mysql.Handler {
func (ic *InterceptorChain) Option() Option {
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) {
chainHandler := buildChain(handler, ic.inters)
return e, sm, chainHandler
}
}

func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler {
// XXX: Mutates |inters|
sort.Slice(inters, func(i, j int) bool {
return inters[i].Priority() < inters[j].Priority()
})
var last Chain = h
for i := len(inters) - 1; i >= 0; i-- {
filter := inters[i]
Expand All @@ -55,7 +67,6 @@ func buildChain(h mysql.Handler) mysql.Handler {
}

type Interceptor interface {

// Priority returns the priority of the interceptor.
Priority() int

Expand Down Expand Up @@ -88,7 +99,6 @@ type Interceptor interface {
}

type Chain interface {

// ComQuery is called when a connection receives a query.
// Note the contents of the query slice may change after
// the first call to callback. So the Handler should not
Expand Down
43 changes: 22 additions & 21 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@ type ProtocolListener interface {
}

// ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given.
type ProtocolListenerFunc func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)

// DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing
// this function will change the protocol listener used when creating all servers. If multiple servers are needed
// with different protocols, then create each server after changing this function. Servers retain the protocol that
// they were created with.
var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
return mysql.NewListenerWithConfig(cfg)
type ProtocolListenerFunc func(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)

func MySQLProtocolListenerFactory(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
vtListener, err := mysql.NewListenerWithConfig(listenerCfg)
if err != nil {
return nil, err
}
if cfg.Version != "" {
vtListener.ServerVersion = cfg.Version
}
vtListener.TLSConfig = cfg.TLSConfig
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
return vtListener, nil
}

type ServerEventListener interface {
Expand Down Expand Up @@ -114,10 +119,6 @@ func portInUse(hostPort string) bool {
}

func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
for _, option := range cfg.Options {
option(e, sm, handler)
}

if cfg.ConnReadTimeout < 0 {
cfg.ConnReadTimeout = 0
}
Expand All @@ -128,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
cfg.MaxConnections = 0
}

for _, opt := range cfg.Options {
e, sm, handler = opt(e, sm, handler)
}

l := cfg.Listener
var unixSocketInUse error
if l == nil {
Expand Down Expand Up @@ -156,19 +161,15 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
ConnReadBufferSize: mysql.DefaultConnBufferSize,
AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS,
}
protocolListener, err := DefaultProtocolListenerFunc(listenerCfg, sel)
plf := cfg.ProtocolListenerFactory
if plf == nil {
plf = MySQLProtocolListenerFactory
}
protocolListener, err := plf(cfg, listenerCfg, sel)
if err != nil {
return nil, err
}

if vtListener, ok := protocolListener.(*mysql.Listener); ok {
if cfg.Version != "" {
vtListener.ServerVersion = cfg.Version
}
vtListener.TLSConfig = cfg.TLSConfig
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
}

return &Server{
Listener: protocolListener,
handler: handler,
Expand Down
15 changes: 10 additions & 5 deletions server/server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ import (
"go.opentelemetry.io/otel/trace"

gms "github.com/dolthub/go-mysql-server"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
)

// Option is an option to customize server.
type Option func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler)

// Server is a MySQL server for SQLe engines.
type Server struct {
Listener ProtocolListener
Expand All @@ -38,6 +34,9 @@ type Server struct {
Engine *gms.Engine
}

// An option to customize the server.
type Option func(e *gms.Engine, sm *SessionManager, handler mysql.Handler) (*gms.Engine, *SessionManager, mysql.Handler)

// Config for the mysql server.
type Config struct {
// Protocol for the connection.
Expand Down Expand Up @@ -82,8 +81,14 @@ type Config struct {
// If true, queries will be logged as base64 encoded strings.
// If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces.
EncodeLoggedQuery bool
// Options add additional options to customize the server.
// Options gets a chance to visit and mutate the GMS *Engine,
// *server.SessionManager and the mysql.Handler as the server
// is being initialized, before the ProtocolListener is
// constructed.
Options []Option
// Used to get the ProtocolListener on server start.
// If unset, defaults to MySQLProtocolListenerFactory.
ProtocolListenerFactory ProtocolListenerFunc
}

func (c Config) NewConfig() (Config, error) {
Expand Down
Loading