Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: implement auth plugin support in the extension framework #53494

Merged
merged 18 commits into from
Jul 3, 2024
Merged
9 changes: 5 additions & 4 deletions pkg/executor/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,15 @@ func (e *GrantExec) Next(ctx context.Context, _ *chunk.Chunk) error {
// It is required for compatibility with 5.7 but removed from 8.0
// since it results in a massive security issue:
// spelling errors will create users with no passwords.
pwd, ok := user.EncodedPassword()
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
authPlugin := mysql.AuthNativePassword
if user.AuthOpt != nil && user.AuthOpt.AuthPlugin != "" {
authPlugin = user.AuthOpt.AuthPlugin
}
authPluginImpl, _ := e.Ctx().GetExtensions().GetAuthPlugin(authPlugin)
pwd, ok := encodePassword(user, authPluginImpl)
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
_, err := internalSession.GetSQLExecutor().ExecuteInternal(internalCtx,
`INSERT INTO %n.%n (Host, User, authentication_string, plugin) VALUES (%?, %?, %?, %?);`,
mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd, authPlugin)
Expand Down
38 changes: 29 additions & 9 deletions pkg/executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/pingcap/tidb/pkg/executor/internal/querywatch"
executor_metrics "github.com/pingcap/tidb/pkg/executor/metrics"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
Expand Down Expand Up @@ -1164,16 +1165,21 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm
return err
}
}
pwd, ok := spec.EncodedPassword()

if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
var pluginImpl *extension.AuthPlugin

switch authPlugin {
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, mysql.AuthTiDBAuthToken, mysql.AuthLDAPSimple, mysql.AuthLDAPSASL:
default:
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
found := false
// If the plugin is not a registered extension auth plugin, return error
if pluginImpl, found = e.Ctx().GetExtensions().GetAuthPlugin(authPlugin); !found {
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
}
}

pwd, ok := encodePassword(spec, pluginImpl)
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}

recordTokenIssuer := tokenIssuer
Expand Down Expand Up @@ -1607,6 +1613,10 @@ func checkPasswordReusePolicy(ctx context.Context, sqlExecutor sqlexec.SQLExecut
// and the Password Reuse Policy does not take effect.
return nil
}
// Skip password reuse checks for extension auth plugins
if _, ok := sctx.GetExtensions().GetAuthPlugin(authPlugin); ok {
return nil
}
// read password reuse info from mysql.user and global variables.
passwdReuseInfo, err := getUserPasswordLimit(ctx, sqlExecutor, userDetail.user, userDetail.host, userDetail.pLI)
if err != nil {
Expand Down Expand Up @@ -1787,6 +1797,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
if spec.AuthOpt.AuthPlugin == "" {
spec.AuthOpt.AuthPlugin = currentAuthPlugin
}
var authPluginImpl *extension.AuthPlugin
switch spec.AuthOpt.AuthPlugin {
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, mysql.AuthLDAPSimple, mysql.AuthLDAPSASL, "":
authTokenOptionHandler = noNeedAuthTokenOptions
Expand All @@ -1795,7 +1806,10 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
authTokenOptionHandler = RequireAuthTokenOptions
}
default:
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
found := false
if authPluginImpl, found = e.Ctx().GetExtensions().GetAuthPlugin(spec.AuthOpt.AuthPlugin); !found {
return exeerrors.ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
}
}
// changing the auth method prunes history.
if spec.AuthOpt.AuthPlugin != currentAuthPlugin {
Expand All @@ -1813,7 +1827,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
return err
}
}
pwd, ok := spec.EncodedPassword()
pwd, ok := encodePassword(spec, authPluginImpl)
if !ok {
return errors.Trace(exeerrors.ErrPasswordFormat)
}
Expand Down Expand Up @@ -2480,7 +2494,13 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error
e.Ctx().GetSessionVars().StmtCtx.AppendNote(exeerrors.ErrSetPasswordAuthPlugin.FastGenByArgs(u, h))
pwd = ""
default:
pwd = auth.EncodePassword(s.Password)
if pluginImpl, ok := e.Ctx().GetExtensions().GetAuthPlugin(authplugin); ok {
if pwd, ok = pluginImpl.GenerateAuthString(s.Password); !ok {
return exeerrors.ErrPasswordFormat.GenWithStackByArgs()
}
} else {
pwd = auth.EncodePassword(s.Password)
}
}

// for Support Password Reuse Policy.
Expand Down
22 changes: 22 additions & 0 deletions pkg/executor/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ package executor

import (
"strings"

"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/parser/ast"
)

var (
Expand Down Expand Up @@ -97,3 +100,22 @@ func (b *batchRetrieverHelper) nextBatch(retrieveRange func(start, end int) erro
}
return nil
}

// encodePassword encodes the password for the user. It invokes the auth plugin if it is available.
func encodePassword(u *ast.UserSpec, authPlugin *extension.AuthPlugin) (string, bool) {
if u.AuthOpt == nil {
return "", true
}
// If the extension auth plugin is available, use it to encode the password.
if authPlugin != nil {
if u.AuthOpt.ByAuthString {
return authPlugin.GenerateAuthString(u.AuthOpt.AuthString)
}
// If we receive a hash string, validate it first.
if authPlugin.ValidateAuthString(u.AuthOpt.HashString) {
return u.AuthOpt.HashString, true
}
return "", false
}
return u.EncodedPassword()
}
44 changes: 44 additions & 0 deletions pkg/executor/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/planner/core"
"github.com/pingcap/tidb/pkg/types"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -133,3 +136,44 @@ func TestEqualDatumsAsBinary(t *testing.T) {
require.Equal(t, tt.same, res)
}
}

func TestEncodePasswordWithPlugin(t *testing.T) {
hashString := "*3D56A309CD04FA2EEF181462E59011F075C89548"
u := &ast.UserSpec{
User: &auth.UserIdentity{
Username: "test",
},
AuthOpt: &ast.AuthOption{
ByAuthString: false,
AuthString: "xxx",
HashString: hashString,
},
}

p := &extension.AuthPlugin{
ValidateAuthString: func(s string) bool {
return false
},
GenerateAuthString: func(s string) (string, bool) {
if s == "xxx" {
return "xxxxxxx", true
}
return "", false
},
}

u.AuthOpt.ByAuthString = false
_, ok := encodePassword(u, p)
require.False(t, ok)

u.AuthOpt.AuthString = "xxx"
u.AuthOpt.ByAuthString = true
pwd, ok := encodePassword(u, p)
require.True(t, ok)
require.Equal(t, "xxxxxxx", pwd)

u.AuthOpt = nil
pwd, ok = encodePassword(u, p)
require.True(t, ok)
require.Equal(t, "", pwd)
}
129 changes: 129 additions & 0 deletions pkg/extension/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package extension

import (
"crypto/tls"
"slices"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/privilege/conn"
)

// AuthPlugin contains attributes needed for an authentication plugin.
type AuthPlugin struct {
// Name is the name of the auth plugin. It will be registered as a system variable in TiDB which can be used inside the `CREATE USER ... IDENTIFIED WITH 'plugin_name'` statement.
Name string

// RequiredClientSidePlugin is the name of the client-side plugin required by the server-side plugin. It will be used to check if the client has the required plugin installed and require the client to use it if installed.
// The user can require default MySQL plugins such as 'caching_sha2_password' or 'mysql_native_password'.
yzhan1 marked this conversation as resolved.
Show resolved Hide resolved
// If this is empty then `AuthPlugin.Name` is used as the required client-side plugin.
RequiredClientSidePlugin string

// AuthenticateUser is called when a client connects to the server as a user and the server authenticates the user.
// If an error is returned, the login attempt fails, otherwise it succeeds.
// request: The request context for the authentication plugin to authenticate a user
AuthenticateUser func(request AuthenticateRequest) error

// GenerateAuthString is a function for user to implement customized ways to encode the password (e.g. hash/salt/clear-text). The returned string will be stored as the encoded password in the mysql.user table.
// If the input password is considered as invalid, this should return an error.
// pwd: User's input password in CREATE/ALTER USER statements in clear-text
GenerateAuthString func(pwd string) (string, bool)

// ValidateAuthString checks if the password hash stored in the mysql.user table or passed in from `IDENTIFIED AS` is valid.
// This is called when retrieving an existing user to make sure the password stored is valid and not modified and make sure user is passing a valid password hash in `IDENTIFIED AS`.
// pwdHash: hash of the password stored in the internal user table
ValidateAuthString func(pwdHash string) bool

// VerifyPrivilege is called for each user queries, and serves as an extra check for privileges for the user.
// It will only be executed if the user has already been granted the privilege in SQL layer.
// Returns true if user has the requested privilege.
// request: The request context for the authorization plugin to authorize a user's static privilege
VerifyPrivilege func(request VerifyStaticPrivRequest) bool

// VerifyDynamicPrivilege is called for each user queries, and serves as an extra check for dynamic privileges for the user.
// It will only be executed if the user has already been granted the dynamic privilege in SQL layer.
// Returns true if user has the requested privilege.
// request: The request context for the authorization plugin to authorize a user's dynamic privilege
VerifyDynamicPrivilege func(request VerifyDynamicPrivRequest) bool
}

// AuthenticateRequest contains the context for the authentication plugin to authenticate a user.
type AuthenticateRequest struct {
// User The username in the connect attempt
User string
// StoredAuthString The user's auth string stored in mysql.user table
StoredAuthString string
// InputAuthString The user's auth string passed in from the connection attempt in bytes
InputAuthString []byte
// Salt Randomly generated salt for the current connection
Salt []byte
// ConnState The TLS connection state (contains the TLS certificate) if client is using TLS. It will be nil if the client is not using TLS
ConnState *tls.ConnectionState
// AuthConn Interface for the plugin to communicate with the client
AuthConn conn.AuthConn
}

// VerifyStaticPrivRequest contains the context for the plugin to authorize a user's static privilege.
type VerifyStaticPrivRequest struct {
// User The username in the connect attempt
User string
// Host The host that the user is connecting from
Host string
// DB The database to check for privilege
DB string
// Table The table to check for privilege
Table string
// Column The column to check for privilege (currently just a placeholder in TiDB as column-level privilege is not supported by TiDB yet)
Column string
// StaticPriv The privilege type of the SQL statement that will be executed
StaticPriv mysql.PrivilegeType
// ConnState The TLS connection state (contains the TLS certificate) if client is using TLS. It will be nil if the client is not using TLS
ConnState *tls.ConnectionState
// ActiveRoles List of active MySQL roles for the current user
ActiveRoles []*auth.RoleIdentity
}

// VerifyDynamicPrivRequest contains the context for the plugin to authorize a user's dynamic privilege.
type VerifyDynamicPrivRequest struct {
// User The username in the connect attempt
User string
// Host The host that the user is connecting from
Host string
// DynamicPriv the dynamic privilege required by the user's SQL statement
DynamicPriv string
// ConnState The TLS connection state (contains the TLS certificate) if client is using TLS. It will be nil if the client is not using TLS
ConnState *tls.ConnectionState
// ActiveRoles List of active MySQL roles for the current user
ActiveRoles []*auth.RoleIdentity
// WithGrant Whether the statement to be executed is granting the user privilege for executing GRANT statements
WithGrant bool
}

// validateAuthPlugin validates the auth plugin functions and attributes.
func validateAuthPlugin(m *Manifest) error {
pluginNames := make(map[string]bool)
// Validate required functions for the auth plugins
for _, p := range m.authPlugins {
if p.Name == "" {
return errors.Errorf("auth plugin name cannot be empty for %s", p.Name)
}
if pluginNames[p.Name] {
return errors.Errorf("auth plugin name %s has already been registered", p.Name)
}
pluginNames[p.Name] = true
if slices.Contains(mysql.DefaultAuthPlugins, p.Name) {
return errors.Errorf("auth plugin name %s is a reserved name for default auth plugins", p.Name)
}
if p.AuthenticateUser == nil {
return errors.Errorf("auth plugin AuthenticateUser function cannot be nil for %s", p.Name)
}
if p.GenerateAuthString == nil {
return errors.Errorf("auth plugin GenerateAuthString function cannot be nil for %s", p.Name)
}
if p.ValidateAuthString == nil {
return errors.Errorf("auth plugin ValidateAuthString function cannot be nil for %s", p.Name)
}
}
return nil
}
Loading