diff --git a/server/extension.go b/server/extension.go index 4fb9079c5d..1568efbc87 100644 --- a/server/extension.go +++ b/server/extension.go @@ -27,24 +27,36 @@ import ( sqle "github.com/dolthub/go-mysql-server" ) -func Intercept(h Interceptor) { - inters = append(inters, h) - sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() }) -} - -func WithChain() Option { - return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) { - f := DefaultProtocolListenerFunc - DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) { - cfg.Handler = buildChain(cfg.Handler) - return f(cfg, sel) - } - } +// InterceptorChain allows an integrator to build a chain of +// |Interceptor| instances which will wrap and intercept the server's +// mysql.Handler. +// +// Example usage: +// +// var ic InterceptorChain +// ic.WithInterceptor(metricsInterceptor) +// ic.WithInterceptor(authInterceptor) +// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...) +type InterceptorChain struct { + inters []Interceptor } -var inters []Interceptor +func (ic *InterceptorChain) WithInterceptor(h Interceptor) { + ic.inters = append(ic.inters, h) +} -func buildChain(h mysql.Handler) mysql.Handler { +func (ic *InterceptorChain) Option() Option { + return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) { + chainHandler := buildChain(handler, ic.inters) + return e, sm, chainHandler + } +} + +func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler { + // XXX: Mutates |inters| + sort.Slice(inters, func(i, j int) bool { + return inters[i].Priority() < inters[j].Priority() + }) var last Chain = h for i := len(inters) - 1; i >= 0; i-- { filter := inters[i] @@ -55,7 +67,6 @@ func buildChain(h mysql.Handler) mysql.Handler { } type Interceptor interface { - // Priority returns the priority of the interceptor. Priority() int @@ -88,7 +99,6 @@ type Interceptor interface { } type Chain interface { - // ComQuery is called when a connection receives a query. // Note the contents of the query slice may change after // the first call to callback. So the Handler should not diff --git a/server/server.go b/server/server.go index 181887e14c..e7f0d6de57 100644 --- a/server/server.go +++ b/server/server.go @@ -37,14 +37,19 @@ type ProtocolListener interface { } // ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given. -type ProtocolListenerFunc func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) - -// DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing -// this function will change the protocol listener used when creating all servers. If multiple servers are needed -// with different protocols, then create each server after changing this function. Servers retain the protocol that -// they were created with. -var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) { - return mysql.NewListenerWithConfig(cfg) +type ProtocolListenerFunc func(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) + +func MySQLProtocolListenerFactory(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) { + vtListener, err := mysql.NewListenerWithConfig(listenerCfg) + if err != nil { + return nil, err + } + if cfg.Version != "" { + vtListener.ServerVersion = cfg.Version + } + vtListener.TLSConfig = cfg.TLSConfig + vtListener.RequireSecureTransport = cfg.RequireSecureTransport + return vtListener, nil } type ServerEventListener interface { @@ -114,10 +119,6 @@ func portInUse(hostPort string) bool { } func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) { - for _, option := range cfg.Options { - option(e, sm, handler) - } - if cfg.ConnReadTimeout < 0 { cfg.ConnReadTimeout = 0 } @@ -128,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle cfg.MaxConnections = 0 } + for _, opt := range cfg.Options { + e, sm, handler = opt(e, sm, handler) + } + l := cfg.Listener var unixSocketInUse error if l == nil { @@ -156,19 +161,15 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle ConnReadBufferSize: mysql.DefaultConnBufferSize, AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS, } - protocolListener, err := DefaultProtocolListenerFunc(listenerCfg, sel) + plf := cfg.ProtocolListenerFactory + if plf == nil { + plf = MySQLProtocolListenerFactory + } + protocolListener, err := plf(cfg, listenerCfg, sel) if err != nil { return nil, err } - if vtListener, ok := protocolListener.(*mysql.Listener); ok { - if cfg.Version != "" { - vtListener.ServerVersion = cfg.Version - } - vtListener.TLSConfig = cfg.TLSConfig - vtListener.RequireSecureTransport = cfg.RequireSecureTransport - } - return &Server{ Listener: protocolListener, handler: handler, diff --git a/server/server_config.go b/server/server_config.go index b7667b9906..4ec35e1cd2 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -23,13 +23,9 @@ import ( "go.opentelemetry.io/otel/trace" gms "github.com/dolthub/go-mysql-server" - sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/sql" ) -// Option is an option to customize server. -type Option func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) - // Server is a MySQL server for SQLe engines. type Server struct { Listener ProtocolListener @@ -38,6 +34,9 @@ type Server struct { Engine *gms.Engine } +// An option to customize the server. +type Option func(e *gms.Engine, sm *SessionManager, handler mysql.Handler) (*gms.Engine, *SessionManager, mysql.Handler) + // Config for the mysql server. type Config struct { // Protocol for the connection. @@ -82,8 +81,14 @@ type Config struct { // If true, queries will be logged as base64 encoded strings. // If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces. EncodeLoggedQuery bool - // Options add additional options to customize the server. + // Options gets a chance to visit and mutate the GMS *Engine, + // *server.SessionManager and the mysql.Handler as the server + // is being initialized, before the ProtocolListener is + // constructed. Options []Option + // Used to get the ProtocolListener on server start. + // If unset, defaults to MySQLProtocolListenerFactory. + ProtocolListenerFactory ProtocolListenerFunc } func (c Config) NewConfig() (Config, error) {