Skip to content

Commit

Permalink
server, sessionctx: support token-based authentication (pingcap#36152)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Jul 14, 2022
1 parent 6af1f4f commit 9a2ed52
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 24 deletions.
21 changes: 21 additions & 0 deletions domain/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/privilege/privileges"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics/handle"
"github.com/pingcap/tidb/telemetry"
Expand Down Expand Up @@ -1634,6 +1635,26 @@ func (do *Domain) NotifyUpdateSysVarCache() {
}
}

// LoadSigningCertLoop loads the signing cert periodically to make sure it's fresh new.
func (do *Domain) LoadSigningCertLoop() {
do.wg.Add(1)
go func() {
defer func() {
do.wg.Done()
logutil.BgLogger().Debug("loadSigningCertLoop exited.")
util.Recover(metrics.LabelDomain, "LoadSigningCertLoop", nil, false)
}()
for {
select {
case <-time.After(sessionstates.LoadCertInterval):
sessionstates.ReloadSigningCert()
case <-do.exit:
return
}
}
}()
}

// ServerID gets serverID.
func (do *Domain) ServerID() uint64 {
return atomic.LoadUint64(&do.serverID)
Expand Down
15 changes: 13 additions & 2 deletions executor/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -2024,8 +2024,19 @@ func (e *ShowExec) fetchShowSessionStates(ctx context.Context) error {
if err = stateJSON.UnmarshalJSON(stateBytes); err != nil {
return err
}
// This will be implemented in future PRs.
tokenBytes, err := gjson.Marshal("")
// session token
var token *sessionstates.SessionToken
// In testing, user may be nil.
if user := e.ctx.GetSessionVars().User; user != nil {
// The token may be leaked without secure transport, so we enforce secure transport (TLS or Unix Socket).
if !e.ctx.GetSessionVars().ConnectionInfo.IsSecureTransport() {
return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("the token must be queried with secure transport")
}
if token, err = sessionstates.CreateSessionToken(user.Username); err != nil {
return err
}
}
tokenBytes, err := gjson.Marshal(token)
if err != nil {
return errors.Trace(err)
}
Expand Down
1 change: 1 addition & 0 deletions parser/mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ const (
AuthNativePassword = "mysql_native_password" // #nosec G101
AuthCachingSha2Password = "caching_sha2_password" // #nosec G101
AuthSocket = "auth_socket"
AuthTiDBSessionToken = "tidb_session_token"
)

// MySQL database and tables.
Expand Down
22 changes: 18 additions & 4 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ import (
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
Expand Down Expand Up @@ -731,6 +732,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
}
case mysql.AuthNativePassword:
case mysql.AuthSocket:
case mysql.AuthTiDBSessionToken:
default:
return errors.New("Unknown auth plugin")
}
Expand All @@ -757,6 +759,7 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo
case mysql.AuthCachingSha2Password:
case mysql.AuthNativePassword:
case mysql.AuthSocket:
case mysql.AuthTiDBSessionToken:
default:
logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin))
}
Expand Down Expand Up @@ -858,7 +861,16 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e
return errAccessDeniedNoPassword.FastGenByArgs(cc.user, host)
}

if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) {
userIdentity := &auth.UserIdentity{Username: cc.user, Hostname: host}
if authPlugin == mysql.AuthTiDBSessionToken {
if !cc.ctx.AuthWithoutVerification(userIdentity) {
return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
if err = sessionstates.ValidateSessionToken(authData, cc.user); err != nil {
logutil.BgLogger().Warn("verify session token failed", zap.String("username", cc.user), zap.Error(err))
return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
} else if !cc.ctx.Auth(userIdentity, authData, cc.salt) {
return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
cc.ctx.SetPort(port)
Expand All @@ -883,6 +895,10 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeRespon
}

authData := resp.Auth
// tidb_session_token is always permitted and skips stored user plugin.
if resp.AuthPlugin == mysql.AuthTiDBSessionToken {
return authData, nil
}
hasPassword := "YES"
if len(authData) == 0 {
hasPassword = "NO"
Expand Down Expand Up @@ -2411,9 +2427,7 @@ func (cc *clientConn) handleResetConnection(ctx context.Context) error {
}

func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error {
if plugin.IsEnable(plugin.Audit) {
cc.ctx.GetSessionVars().ConnectionInfo = cc.connectInfo()
}
cc.ctx.GetSessionVars().ConnectionInfo = cc.connectInfo()

err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
Expand Down
88 changes: 88 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"path/filepath"
"strings"
"sync/atomic"
"testing"
Expand All @@ -30,13 +33,15 @@ import (
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/store/mockstore/unistore"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/testkit/external"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/arena"
"github.com/pingcap/tidb/util/chunk"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1258,6 +1263,89 @@ func TestAuthPlugin2(t *testing.T) {

}

func TestAuthTokenPlugin(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()

cfg := newTestConfig()
cfg.Port = 0
cfg.Status.StatusPort = 0
drv := NewTiDBDriver(store)
srv, err := NewServer(cfg, drv)
require.NoError(t, err)
ctx := context.Background()

// create the cert
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "test1_cert.pem")
keyPath := filepath.Join(tempDir, "test1_key.pem")
err = util.CreateCertificates(certPath, keyPath, 4096, x509.RSA, x509.UnknownSignatureAlgorithm)
require.NoError(t, err)

tk := testkit.NewTestKit(t, store)
tk.MustExec("CREATE USER auth_session_token")
tk.MustExec("CREATE USER another_user")
tk.MustExec(fmt.Sprintf("set global %s='%s'", variable.TiDBAuthSigningCert, certPath))
tk.MustExec(fmt.Sprintf("set global %s='%s'", variable.TiDBAuthSigningKey, keyPath))

tc, err := drv.OpenCtx(uint64(0), 0, uint8(mysql.DefaultCollationID), "", nil)
require.NoError(t, err)
cc := &clientConn{
connectionID: 1,
alloc: arena.NewAllocator(1024),
chunkAlloc: chunk.NewAllocator(),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
pkt: &packetIO{
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "auth_session_token",
}
cc.setCtx(tc)
// create a token without TLS
tk1 := testkit.NewTestKitWithSession(t, store, tc.Session)
tc.Session.GetSessionVars().ConnectionInfo = cc.connectInfo()
tk1.Session().Auth(&auth.UserIdentity{Username: "auth_session_token", Hostname: "localhost"}, nil, nil)
err = tk1.QueryToErr("show session_states")
require.ErrorContains(t, err, "secure transport")

// create a token with TLS
cc.tlsConn = &tls.Conn{}
tc.Session.GetSessionVars().ConnectionInfo = cc.connectInfo()
tk1.Session().Auth(&auth.UserIdentity{Username: "auth_session_token", Hostname: "localhost"}, nil, nil)
tk1.MustQuery("show session_states")

// create a token with UnixSocket
cc.tlsConn = nil
cc.isUnixSocket = true
tc.Session.GetSessionVars().ConnectionInfo = cc.connectInfo()
rows := tk1.MustQuery("show session_states").Rows()
tokenBytes := []byte(rows[0][1].(string))

// auth with the token
resp := handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthTiDBSessionToken,
Auth: tokenBytes,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
require.NoError(t, err)

// wrong token should fail
tokenBytes[0] ^= 0xff
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
require.ErrorContains(t, err, "Access denied")
tokenBytes[0] ^= 0xff

// using the token to auth with another user should fail
cc.user = "another_user"
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
require.ErrorContains(t, err, "Access denied")
}

func TestMaxAllowedPacket(t *testing.T) {
// Test cases from issue 31422: https://github.com/pingcap/tidb/issues/31422
// The string "SELECT length('') as len;" has 25 chars,
Expand Down
14 changes: 4 additions & 10 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,7 @@ func (s *Server) onConn(conn *clientConn) {
metrics.ConnGauge.Set(float64(connections))

sessionVars := conn.ctx.GetSessionVars()
if plugin.IsEnable(plugin.Audit) {
sessionVars.ConnectionInfo = conn.connectInfo()
}
sessionVars.ConnectionInfo = conn.connectInfo()
err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
Expand All @@ -565,10 +563,6 @@ func (s *Server) onConn(conn *clientConn) {
conn.Run(ctx)

err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
// Audit plugin may be disabled before a conn is created, leading no connectionInfo in sessionVars.
if sessionVars.ConnectionInfo == nil {
sessionVars.ConnectionInfo = conn.connectInfo()
}
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond)
Expand All @@ -585,11 +579,11 @@ func (s *Server) onConn(conn *clientConn) {
}

func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
connType := "Socket"
connType := variable.ConnTypeSocket
if cc.isUnixSocket {
connType = "UnixSocket"
connType = variable.ConnTypeUnixSocket
} else if cc.tlsConn != nil {
connType = "SSL/TLS"
connType = variable.ConnTypeTLS
}
connInfo := &variable.ConnectionInfo{
ConnectionID: cc.connectionID,
Expand Down
1 change: 1 addition & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2804,6 +2804,7 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) {
}

dom.DumpFileGcCheckerLoop()
dom.LoadSigningCertLoop()

if raw, ok := store.(kv.EtcdBackend); ok {
err = raw.StartGCWorker()
Expand Down
8 changes: 4 additions & 4 deletions sessionctx/sessionstates/session_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ import (
const (
// A token needs a lifetime to avoid brute force attack.
tokenLifetime = time.Minute
// Reload the certificate periodically because it may be rotated.
loadCertInterval = 10 * time.Minute
// LoadCertInterval is the interval of reloading the certificate. The certificate should be rotated periodically.
LoadCertInterval = 10 * time.Minute
// After a certificate is replaced, it's still valid for oldCertValidTime.
// oldCertValidTime must be a little longer than loadCertInterval, because the previous server may
// oldCertValidTime must be a little longer than LoadCertInterval, because the previous server may
// sign with the old cert but the new server checks with the new cert.
// - server A loads the old cert at 00:00:00.
// - the cert is rotated at 00:00:01 on all servers.
Expand Down Expand Up @@ -224,7 +224,7 @@ func (sc *signingCert) loadCert() error {
newCerts = append(newCerts, &certInfo{
cert: cert,
privKey: tlsCert.PrivateKey,
expireTime: now.Add(loadCertInterval + oldCertValidTime),
expireTime: now.Add(LoadCertInterval + oldCertValidTime),
})
for i := 0; i < len(sc.certs); i++ {
// Discard the certs that are already expired.
Expand Down
4 changes: 2 additions & 2 deletions sessionctx/sessionstates/session_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestCertExpire(t *testing.T) {
err = ValidateSessionToken(tokenBytes, "test_user")
require.NoError(t, err)
// the old cert expires and the original token is invalid
timeOffset := uint64(loadCertInterval)
timeOffset := uint64(LoadCertInterval)
require.NoError(t, failpoint.Enable(mockNowOffset, fmt.Sprintf(`return(%d)`, timeOffset)))
ReloadSigningCert()
timeOffset += uint64(oldCertValidTime + time.Minute)
Expand All @@ -197,7 +197,7 @@ func TestCertExpire(t *testing.T) {
require.NoError(t, err)
// the cert is rotated but is still valid
createRSACert(t, certPath2, keyPath2)
timeOffset += uint64(loadCertInterval)
timeOffset += uint64(LoadCertInterval)
require.NoError(t, failpoint.Enable(mockNowOffset, fmt.Sprintf(`return(%d)`, timeOffset)))
ReloadSigningCert()
err = ValidateSessionToken(tokenBytes, "test_user")
Expand Down
22 changes: 20 additions & 2 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ type SessionVars struct {
// Killed is a flag to indicate that this query is killed.
Killed uint32

// ConnectionInfo indicates current connection info used by current session, only be lazy assigned by plugin.
// ConnectionInfo indicates current connection info used by current session.
ConnectionInfo *ConnectionInfo

// NoopFuncsMode allows OFF/ON/WARN values as 0/1/2.
Expand Down Expand Up @@ -1296,7 +1296,7 @@ func (pps PreparedParams) String() string {
return " [arguments: " + types.DatumsToStrNoErr(pps) + "]"
}

// ConnectionInfo present connection used by audit.
// ConnectionInfo presents the connection information, which is mainly used by audit logs.
type ConnectionInfo struct {
ConnectionID uint64
ConnectionType string
Expand All @@ -1316,6 +1316,24 @@ type ConnectionInfo struct {
DB string
}

const (
// ConnTypeSocket indicates socket without TLS.
ConnTypeSocket string = "Socket"
// ConnTypeUnixSocket indicates Unix Socket.
ConnTypeUnixSocket string = "UnixSocket"
// ConnTypeTLS indicates socket with TLS.
ConnTypeTLS string = "SSL/TLS"
)

// IsSecureTransport checks whether the connection is secure.
func (connInfo *ConnectionInfo) IsSecureTransport() bool {
switch connInfo.ConnectionType {
case ConnTypeUnixSocket, ConnTypeTLS:
return true
}
return false
}

// NewSessionVars creates a session vars object.
func NewSessionVars() *SessionVars {
vars := &SessionVars{
Expand Down
9 changes: 9 additions & 0 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/collate"
Expand Down Expand Up @@ -813,6 +814,14 @@ var defaultSysVars = []*SysVar{
}, GetGlobal: func(s *SessionVars) (string, error) {
return BoolToOnOff(EnableNoopVariables.Load()), nil
}},
{Scope: ScopeGlobal, Name: TiDBAuthSigningCert, Value: "", Type: TypeStr, SetGlobal: func(s *SessionVars, val string) error {
sessionstates.SetCertPath(val)
return nil
}},
{Scope: ScopeGlobal, Name: TiDBAuthSigningKey, Value: "", Type: TypeStr, SetGlobal: func(s *SessionVars, val string) error {
sessionstates.SetKeyPath(val)
return nil
}},

/* The system variables below have GLOBAL and SESSION scope */
{Scope: ScopeGlobal | ScopeSession, Name: SQLSelectLimit, Value: "18446744073709551615", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64, SetSession: func(s *SessionVars, val string) error {
Expand Down
Loading

0 comments on commit 9a2ed52

Please sign in to comment.