Skip to content

Commit

Permalink
Merge pull request #2879 from dolthub/aaron/server-protocol-listener-…
Browse files Browse the repository at this point in the history
…fixup

server: Get rid of globals for setting a protocol listener factory. Get rid of unused, global-ridden and complicated Interceptor and Option functionality.
  • Loading branch information
reltuk authored Mar 5, 2025
2 parents d06023d + 1a65d1c commit 14a57e0
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 43 deletions.
44 changes: 27 additions & 17 deletions server/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,36 @@ import (
sqle "github.com/dolthub/go-mysql-server"
)

func Intercept(h Interceptor) {
inters = append(inters, h)
sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() })
}

func WithChain() Option {
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) {
f := DefaultProtocolListenerFunc
DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
cfg.Handler = buildChain(cfg.Handler)
return f(cfg, sel)
}
}
// InterceptorChain allows an integrator to build a chain of
// |Interceptor| instances which will wrap and intercept the server's
// mysql.Handler.
//
// Example usage:
//
// var ic InterceptorChain
// ic.WithInterceptor(metricsInterceptor)
// ic.WithInterceptor(authInterceptor)
// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...)
type InterceptorChain struct {
inters []Interceptor
}

var inters []Interceptor
func (ic *InterceptorChain) WithInterceptor(h Interceptor) {
ic.inters = append(ic.inters, h)
}

func buildChain(h mysql.Handler) mysql.Handler {
func (ic *InterceptorChain) Option() Option {
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) {
chainHandler := buildChain(handler, ic.inters)
return e, sm, chainHandler
}
}

func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler {
// XXX: Mutates |inters|
sort.Slice(inters, func(i, j int) bool {
return inters[i].Priority() < inters[j].Priority()
})
var last Chain = h
for i := len(inters) - 1; i >= 0; i-- {
filter := inters[i]
Expand All @@ -55,7 +67,6 @@ func buildChain(h mysql.Handler) mysql.Handler {
}

type Interceptor interface {

// Priority returns the priority of the interceptor.
Priority() int

Expand Down Expand Up @@ -88,7 +99,6 @@ type Interceptor interface {
}

type Chain interface {

// ComQuery is called when a connection receives a query.
// Note the contents of the query slice may change after
// the first call to callback. So the Handler should not
Expand Down
43 changes: 22 additions & 21 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@ type ProtocolListener interface {
}

// ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given.
type ProtocolListenerFunc func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)

// DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing
// this function will change the protocol listener used when creating all servers. If multiple servers are needed
// with different protocols, then create each server after changing this function. Servers retain the protocol that
// they were created with.
var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
return mysql.NewListenerWithConfig(cfg)
type ProtocolListenerFunc func(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)

func MySQLProtocolListenerFactory(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
vtListener, err := mysql.NewListenerWithConfig(listenerCfg)
if err != nil {
return nil, err
}
if cfg.Version != "" {
vtListener.ServerVersion = cfg.Version
}
vtListener.TLSConfig = cfg.TLSConfig
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
return vtListener, nil
}

type ServerEventListener interface {
Expand Down Expand Up @@ -114,10 +119,6 @@ func portInUse(hostPort string) bool {
}

func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
for _, option := range cfg.Options {
option(e, sm, handler)
}

if cfg.ConnReadTimeout < 0 {
cfg.ConnReadTimeout = 0
}
Expand All @@ -128,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
cfg.MaxConnections = 0
}

for _, opt := range cfg.Options {
e, sm, handler = opt(e, sm, handler)
}

l := cfg.Listener
var unixSocketInUse error
if l == nil {
Expand Down Expand Up @@ -156,19 +161,15 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
ConnReadBufferSize: mysql.DefaultConnBufferSize,
AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS,
}
protocolListener, err := DefaultProtocolListenerFunc(listenerCfg, sel)
plf := cfg.ProtocolListenerFactory
if plf == nil {
plf = MySQLProtocolListenerFactory
}
protocolListener, err := plf(cfg, listenerCfg, sel)
if err != nil {
return nil, err
}

if vtListener, ok := protocolListener.(*mysql.Listener); ok {
if cfg.Version != "" {
vtListener.ServerVersion = cfg.Version
}
vtListener.TLSConfig = cfg.TLSConfig
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
}

return &Server{
Listener: protocolListener,
handler: handler,
Expand Down
15 changes: 10 additions & 5 deletions server/server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ import (
"go.opentelemetry.io/otel/trace"

gms "github.com/dolthub/go-mysql-server"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
)

// Option is an option to customize server.
type Option func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler)

// Server is a MySQL server for SQLe engines.
type Server struct {
Listener ProtocolListener
Expand All @@ -38,6 +34,9 @@ type Server struct {
Engine *gms.Engine
}

// An option to customize the server.
type Option func(e *gms.Engine, sm *SessionManager, handler mysql.Handler) (*gms.Engine, *SessionManager, mysql.Handler)

// Config for the mysql server.
type Config struct {
// Protocol for the connection.
Expand Down Expand Up @@ -82,8 +81,14 @@ type Config struct {
// If true, queries will be logged as base64 encoded strings.
// If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces.
EncodeLoggedQuery bool
// Options add additional options to customize the server.
// Options gets a chance to visit and mutate the GMS *Engine,
// *server.SessionManager and the mysql.Handler as the server
// is being initialized, before the ProtocolListener is
// constructed.
Options []Option
// Used to get the ProtocolListener on server start.
// If unset, defaults to MySQLProtocolListenerFactory.
ProtocolListenerFactory ProtocolListenerFunc
}

func (c Config) NewConfig() (Config, error) {
Expand Down

0 comments on commit 14a57e0

Please sign in to comment.