From 2d4ebe5807d1a8a0d6ddf831d2a0169b13704624 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Wed, 5 Mar 2025 12:32:58 -0800 Subject: [PATCH 1/2] server: Get rid of globals for setting a protocol listener factory. Get rid of unused, global-ridden and complicated Interceptor and Option functionality. --- server/extension.go | 183 ---------------------------------------- server/server.go | 39 ++++----- server/server_config.go | 9 +- 3 files changed, 21 insertions(+), 210 deletions(-) delete mode 100644 server/extension.go diff --git a/server/extension.go b/server/extension.go deleted file mode 100644 index 4fb9079c5d..0000000000 --- a/server/extension.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2020-2021 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package server - -import ( - "context" - "sort" - - "github.com/dolthub/vitess/go/mysql" - "github.com/dolthub/vitess/go/sqltypes" - querypb "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/dolthub/vitess/go/vt/sqlparser" - ast "github.com/dolthub/vitess/go/vt/sqlparser" - - sqle "github.com/dolthub/go-mysql-server" -) - -func Intercept(h Interceptor) { - inters = append(inters, h) - sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() }) -} - -func WithChain() Option { - return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) { - f := DefaultProtocolListenerFunc - DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) { - cfg.Handler = buildChain(cfg.Handler) - return f(cfg, sel) - } - } -} - -var inters []Interceptor - -func buildChain(h mysql.Handler) mysql.Handler { - var last Chain = h - for i := len(inters) - 1; i >= 0; i-- { - filter := inters[i] - next := last - last = &chainInterceptor{i: filter, c: next} - } - return &interceptorHandler{h: h, c: last} -} - -type Interceptor interface { - - // Priority returns the priority of the interceptor. - Priority() int - - // Query is called when a connection receives a query. - // Note the contents of the query slice may change after - // the first call to callback. So the Handler should not - // hang on to the byte slice. - Query(ctx context.Context, chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) error - - // ParsedQuery is called when a connection receives a - // query that has already been parsed. Note the contents - // of the query slice may change after the first call to - // callback. So the Handler should not hang on to the byte - // slice. - ParsedQuery(chain Chain, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(res *sqltypes.Result, more bool) error) error - - // MultiQuery is called when a connection receives a query and the - // client supports MULTI_STATEMENT. It should process the first - // statement in |query| and return the remainder. It will be called - // multiple times until the remainder is |""|. - MultiQuery(ctx context.Context, chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) (string, error) - - // Prepare is called when a connection receives a prepared - // statement query. - Prepare(ctx context.Context, chain Chain, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) - - // StmtExecute is called when a connection receives a statement - // execute query. - StmtExecute(ctx context.Context, chain Chain, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error -} - -type Chain interface { - - // ComQuery is called when a connection receives a query. - // Note the contents of the query slice may change after - // the first call to callback. So the Handler should not - // hang on to the byte slice. - ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error - - // ComMultiQuery is called when a connection receives a query and the - // client supports MULTI_STATEMENT. It should process the first - // statement in |query| and return the remainder. It will be called - // multiple times until the remainder is |""|. - ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) - - // ComPrepare is called when a connection receives a prepared - // statement query. - ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) - - // ComStmtExecute is called when a connection receives a statement - // execute query. - ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error -} - -type chainInterceptor struct { - i Interceptor - c Chain -} - -func (ci *chainInterceptor) ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error { - return ci.i.Query(ctx, ci.c, c, query, callback) -} - -func (ci *chainInterceptor) ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) { - return ci.i.MultiQuery(ctx, ci.c, c, query, callback) -} - -func (ci *chainInterceptor) ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) { - return ci.i.Prepare(ctx, ci.c, c, query, prepare) -} - -func (ci *chainInterceptor) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { - return ci.i.StmtExecute(ctx, ci.c, c, prepare, callback) -} - -type interceptorHandler struct { - c Chain - h mysql.Handler -} - -var _ mysql.Handler = (*interceptorHandler)(nil) - -func (ih *interceptorHandler) NewConnection(c *mysql.Conn) { - ih.h.NewConnection(c) -} - -func (ih *interceptorHandler) ConnectionClosed(c *mysql.Conn) { - ih.h.ConnectionClosed(c) -} - -func (ih *interceptorHandler) ConnectionAborted(c *mysql.Conn, reason string) error { - return ih.h.ConnectionAborted(c, reason) -} - -func (ih *interceptorHandler) ComInitDB(c *mysql.Conn, schemaName string) error { - return ih.h.ComInitDB(c, schemaName) -} - -func (ih *interceptorHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error { - return ih.c.ComQuery(ctx, c, query, callback) -} - -func (ih *interceptorHandler) ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) { - return ih.c.ComMultiQuery(ctx, c, query, callback) -} - -func (ih *interceptorHandler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) { - return ih.c.ComPrepare(ctx, c, query, prepare) -} - -func (ih *interceptorHandler) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { - return ih.c.ComStmtExecute(ctx, c, prepare, callback) -} - -func (ih *interceptorHandler) WarningCount(c *mysql.Conn) uint16 { - return ih.h.WarningCount(c) -} - -func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) error { - return ih.h.ComResetConnection(c) -} - -func (ih *interceptorHandler) ParserOptionsForConnection(c *mysql.Conn) (ast.ParserOptions, error) { - return ih.h.ParserOptionsForConnection(c) -} diff --git a/server/server.go b/server/server.go index 181887e14c..3b2b4d0a93 100644 --- a/server/server.go +++ b/server/server.go @@ -37,14 +37,19 @@ type ProtocolListener interface { } // ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given. -type ProtocolListenerFunc func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) - -// DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing -// this function will change the protocol listener used when creating all servers. If multiple servers are needed -// with different protocols, then create each server after changing this function. Servers retain the protocol that -// they were created with. -var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) { - return mysql.NewListenerWithConfig(cfg) +type ProtocolListenerFunc func(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) + +func MySQLProtocolListenerFactory(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) { + vtListener, err := mysql.NewListenerWithConfig(listenerCfg) + if err != nil { + return nil, err + } + if cfg.Version != "" { + vtListener.ServerVersion = cfg.Version + } + vtListener.TLSConfig = cfg.TLSConfig + vtListener.RequireSecureTransport = cfg.RequireSecureTransport + return vtListener, nil } type ServerEventListener interface { @@ -114,10 +119,6 @@ func portInUse(hostPort string) bool { } func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) { - for _, option := range cfg.Options { - option(e, sm, handler) - } - if cfg.ConnReadTimeout < 0 { cfg.ConnReadTimeout = 0 } @@ -156,19 +157,15 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle ConnReadBufferSize: mysql.DefaultConnBufferSize, AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS, } - protocolListener, err := DefaultProtocolListenerFunc(listenerCfg, sel) + plf := cfg.ProtocolListenerFactory + if plf == nil { + plf = MySQLProtocolListenerFactory + } + protocolListener, err := plf(cfg, listenerCfg, sel) if err != nil { return nil, err } - if vtListener, ok := protocolListener.(*mysql.Listener); ok { - if cfg.Version != "" { - vtListener.ServerVersion = cfg.Version - } - vtListener.TLSConfig = cfg.TLSConfig - vtListener.RequireSecureTransport = cfg.RequireSecureTransport - } - return &Server{ Listener: protocolListener, handler: handler, diff --git a/server/server_config.go b/server/server_config.go index b7667b9906..0fdb0492b5 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -23,13 +23,9 @@ import ( "go.opentelemetry.io/otel/trace" gms "github.com/dolthub/go-mysql-server" - sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/sql" ) -// Option is an option to customize server. -type Option func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) - // Server is a MySQL server for SQLe engines. type Server struct { Listener ProtocolListener @@ -82,8 +78,9 @@ type Config struct { // If true, queries will be logged as base64 encoded strings. // If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces. EncodeLoggedQuery bool - // Options add additional options to customize the server. - Options []Option + // Used to get the ProtocolListener on server start. + // If unset, defaults to MySQLProtocolListenerFactory. + ProtocolListenerFactory ProtocolListenerFunc } func (c Config) NewConfig() (Config, error) { From 1a65d1c3e944ad0479f9350188c025c06bb2e0d1 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Wed, 5 Mar 2025 14:12:30 -0800 Subject: [PATCH 2/2] server: Revive Option and Interceptor, but without the globals. --- server/extension.go | 193 ++++++++++++++++++++++++++++++++++++++++ server/server.go | 4 + server/server_config.go | 8 ++ 3 files changed, 205 insertions(+) create mode 100644 server/extension.go diff --git a/server/extension.go b/server/extension.go new file mode 100644 index 0000000000..1568efbc87 --- /dev/null +++ b/server/extension.go @@ -0,0 +1,193 @@ +// Copyright 2020-2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "sort" + + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/vitess/go/sqltypes" + querypb "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/dolthub/vitess/go/vt/sqlparser" + ast "github.com/dolthub/vitess/go/vt/sqlparser" + + sqle "github.com/dolthub/go-mysql-server" +) + +// InterceptorChain allows an integrator to build a chain of +// |Interceptor| instances which will wrap and intercept the server's +// mysql.Handler. +// +// Example usage: +// +// var ic InterceptorChain +// ic.WithInterceptor(metricsInterceptor) +// ic.WithInterceptor(authInterceptor) +// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...) +type InterceptorChain struct { + inters []Interceptor +} + +func (ic *InterceptorChain) WithInterceptor(h Interceptor) { + ic.inters = append(ic.inters, h) +} + +func (ic *InterceptorChain) Option() Option { + return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) { + chainHandler := buildChain(handler, ic.inters) + return e, sm, chainHandler + } +} + +func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler { + // XXX: Mutates |inters| + sort.Slice(inters, func(i, j int) bool { + return inters[i].Priority() < inters[j].Priority() + }) + var last Chain = h + for i := len(inters) - 1; i >= 0; i-- { + filter := inters[i] + next := last + last = &chainInterceptor{i: filter, c: next} + } + return &interceptorHandler{h: h, c: last} +} + +type Interceptor interface { + // Priority returns the priority of the interceptor. + Priority() int + + // Query is called when a connection receives a query. + // Note the contents of the query slice may change after + // the first call to callback. So the Handler should not + // hang on to the byte slice. + Query(ctx context.Context, chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) error + + // ParsedQuery is called when a connection receives a + // query that has already been parsed. Note the contents + // of the query slice may change after the first call to + // callback. So the Handler should not hang on to the byte + // slice. + ParsedQuery(chain Chain, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(res *sqltypes.Result, more bool) error) error + + // MultiQuery is called when a connection receives a query and the + // client supports MULTI_STATEMENT. It should process the first + // statement in |query| and return the remainder. It will be called + // multiple times until the remainder is |""|. + MultiQuery(ctx context.Context, chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) (string, error) + + // Prepare is called when a connection receives a prepared + // statement query. + Prepare(ctx context.Context, chain Chain, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) + + // StmtExecute is called when a connection receives a statement + // execute query. + StmtExecute(ctx context.Context, chain Chain, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error +} + +type Chain interface { + // ComQuery is called when a connection receives a query. + // Note the contents of the query slice may change after + // the first call to callback. So the Handler should not + // hang on to the byte slice. + ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error + + // ComMultiQuery is called when a connection receives a query and the + // client supports MULTI_STATEMENT. It should process the first + // statement in |query| and return the remainder. It will be called + // multiple times until the remainder is |""|. + ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) + + // ComPrepare is called when a connection receives a prepared + // statement query. + ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) + + // ComStmtExecute is called when a connection receives a statement + // execute query. + ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error +} + +type chainInterceptor struct { + i Interceptor + c Chain +} + +func (ci *chainInterceptor) ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error { + return ci.i.Query(ctx, ci.c, c, query, callback) +} + +func (ci *chainInterceptor) ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) { + return ci.i.MultiQuery(ctx, ci.c, c, query, callback) +} + +func (ci *chainInterceptor) ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) { + return ci.i.Prepare(ctx, ci.c, c, query, prepare) +} + +func (ci *chainInterceptor) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return ci.i.StmtExecute(ctx, ci.c, c, prepare, callback) +} + +type interceptorHandler struct { + c Chain + h mysql.Handler +} + +var _ mysql.Handler = (*interceptorHandler)(nil) + +func (ih *interceptorHandler) NewConnection(c *mysql.Conn) { + ih.h.NewConnection(c) +} + +func (ih *interceptorHandler) ConnectionClosed(c *mysql.Conn) { + ih.h.ConnectionClosed(c) +} + +func (ih *interceptorHandler) ConnectionAborted(c *mysql.Conn, reason string) error { + return ih.h.ConnectionAborted(c, reason) +} + +func (ih *interceptorHandler) ComInitDB(c *mysql.Conn, schemaName string) error { + return ih.h.ComInitDB(c, schemaName) +} + +func (ih *interceptorHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error { + return ih.c.ComQuery(ctx, c, query, callback) +} + +func (ih *interceptorHandler) ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) { + return ih.c.ComMultiQuery(ctx, c, query, callback) +} + +func (ih *interceptorHandler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) { + return ih.c.ComPrepare(ctx, c, query, prepare) +} + +func (ih *interceptorHandler) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return ih.c.ComStmtExecute(ctx, c, prepare, callback) +} + +func (ih *interceptorHandler) WarningCount(c *mysql.Conn) uint16 { + return ih.h.WarningCount(c) +} + +func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) error { + return ih.h.ComResetConnection(c) +} + +func (ih *interceptorHandler) ParserOptionsForConnection(c *mysql.Conn) (ast.ParserOptions, error) { + return ih.h.ParserOptionsForConnection(c) +} diff --git a/server/server.go b/server/server.go index 3b2b4d0a93..e7f0d6de57 100644 --- a/server/server.go +++ b/server/server.go @@ -129,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle cfg.MaxConnections = 0 } + for _, opt := range cfg.Options { + e, sm, handler = opt(e, sm, handler) + } + l := cfg.Listener var unixSocketInUse error if l == nil { diff --git a/server/server_config.go b/server/server_config.go index 0fdb0492b5..4ec35e1cd2 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -34,6 +34,9 @@ type Server struct { Engine *gms.Engine } +// An option to customize the server. +type Option func(e *gms.Engine, sm *SessionManager, handler mysql.Handler) (*gms.Engine, *SessionManager, mysql.Handler) + // Config for the mysql server. type Config struct { // Protocol for the connection. @@ -78,6 +81,11 @@ type Config struct { // If true, queries will be logged as base64 encoded strings. // If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces. EncodeLoggedQuery bool + // Options gets a chance to visit and mutate the GMS *Engine, + // *server.SessionManager and the mysql.Handler as the server + // is being initialized, before the ProtocolListener is + // constructed. + Options []Option // Used to get the ProtocolListener on server start. // If unset, defaults to MySQLProtocolListenerFactory. ProtocolListenerFactory ProtocolListenerFunc