Skip to content

Commit

Permalink
*: Implementing RENAME USER (#24413)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjonss authored May 19, 2021
1 parent 1136126 commit 15dfd7b
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 5 deletions.
142 changes: 141 additions & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
err = e.executeAlterUser(x)
case *ast.DropUserStmt:
err = e.executeDropUser(x)
case *ast.RenameUserStmt:
err = e.executeRenameUser(x)
case *ast.SetPwdStmt:
err = e.executeSetPwd(x)
case *ast.KillStmt:
Expand Down Expand Up @@ -1026,6 +1028,123 @@ func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error {
return nil
}

// Should cover same internal mysql.* tables as DROP USER, so this function is very similar
func (e *SimpleExec) executeRenameUser(s *ast.RenameUserStmt) error {

var failedUser string
sysSession, err := e.getSysSession()
defer e.releaseSysSession(sysSession)
if err != nil {
return err
}
sqlExecutor := sysSession.(sqlexec.SQLExecutor)

if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil {
return err
}

for _, userToUser := range s.UserToUsers {
oldUser, newUser := userToUser.OldUser, userToUser.NewUser
exists, err := userExistsInternal(sqlExecutor, oldUser.Username, oldUser.Hostname)
if err != nil {
return err
}
if !exists {
failedUser = oldUser.String() + " TO " + newUser.String() + " old did not exist"
break
}

exists, err = userExistsInternal(sqlExecutor, newUser.Username, newUser.Hostname)
if err != nil {
return err
}
if exists {
// MySQL reports the old user, even when the issue is the new user.
failedUser = oldUser.String() + " TO " + newUser.String() + " new did exist"
break
}

if err = renameUserHostInSystemTable(sqlExecutor, mysql.UserTable, "User", "Host", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.UserTable + " error"
break
}

// rename privileges from mysql.global_priv
if err = renameUserHostInSystemTable(sqlExecutor, mysql.GlobalPrivTable, "User", "Host", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.GlobalPrivTable + " error"
break
}

// rename privileges from mysql.db
if err = renameUserHostInSystemTable(sqlExecutor, mysql.DBTable, "User", "Host", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.DBTable + " error"
break
}

// rename privileges from mysql.tables_priv
if err = renameUserHostInSystemTable(sqlExecutor, mysql.TablePrivTable, "User", "Host", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.TablePrivTable + " error"
break
}

// rename relationship from mysql.role_edges
if err = renameUserHostInSystemTable(sqlExecutor, mysql.RoleEdgeTable, "TO_USER", "TO_HOST", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.RoleEdgeTable + " (to) error"
break
}

if err = renameUserHostInSystemTable(sqlExecutor, mysql.RoleEdgeTable, "FROM_USER", "FROM_HOST", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.RoleEdgeTable + " (from) error"
break
}

// rename relationship from mysql.default_roles
if err = renameUserHostInSystemTable(sqlExecutor, mysql.DefaultRoleTable, "DEFAULT_ROLE_USER", "DEFAULT_ROLE_HOST", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.DefaultRoleTable + " (default role user) error"
break
}

if err = renameUserHostInSystemTable(sqlExecutor, mysql.DefaultRoleTable, "USER", "HOST", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " " + mysql.DefaultRoleTable + " error"
break
}

// rename relationship from mysql.global_grants
// TODO: add global_grants into the parser
if err = renameUserHostInSystemTable(sqlExecutor, "global_grants", "User", "Host", userToUser); err != nil {
failedUser = oldUser.String() + " TO " + newUser.String() + " mysql.global_grants error"
break
}

//TODO: need update columns_priv once we implement columns_priv functionality.
// When that is added, please refactor both executeRenameUser and executeDropUser to use an array of tables
// to loop over, so it is easier to maintain.
}

if failedUser == "" {
if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil {
return err
}
} else {
if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil {
return err
}
return ErrCannotUser.GenWithStackByArgs("RENAME USER", failedUser)
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
}

func renameUserHostInSystemTable(sqlExecutor sqlexec.SQLExecutor, tableName, usernameColumn, hostColumn string, users *ast.UserToUser) error {
sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, `UPDATE %n.%n SET %n = %?, %n = %? WHERE %n = %? and %n = %?;`,
mysql.SystemDB, tableName,
usernameColumn, users.NewUser.Username, hostColumn, users.NewUser.Hostname,
usernameColumn, users.OldUser.Username, hostColumn, users.OldUser.Hostname)
_, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String())
return err
}

func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error {
// Check privileges.
// Check `CREATE USER` privilege.
Expand Down Expand Up @@ -1181,6 +1300,27 @@ func userExists(ctx sessionctx.Context, name string, host string) (bool, error)
return len(rows) > 0, nil
}

// use the same internal executor to read within the same transaction, otherwise same as userExists
func userExistsInternal(sqlExecutor sqlexec.SQLExecutor, name string, host string) (bool, error) {
sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, host)
recordSet, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String())
if err != nil {
return false, err
}
req := recordSet.NewChunk()
err = recordSet.Next(context.TODO(), req)
var rows int = 0
if err == nil {
rows = req.NumRows()
}
errClose := recordSet.Close()
if errClose != nil {
return false, errClose
}
return rows > 0, err
}

func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error {
var u, h string
if s.User == nil {
Expand Down Expand Up @@ -1389,7 +1529,7 @@ func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) (err error) {

func (e *SimpleExec) autoNewTxn() bool {
switch e.Statement.(type) {
case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt:
case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt, *ast.RenameUserStmt:
return true
}
return false
Expand Down
6 changes: 6 additions & 0 deletions planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,12 @@ func (s *testPlanSuite) TestVisitInfo(c *C) {
{mysql.ExtendedPriv, "", "", "", ErrSpecificAccessDenied, false, "BACKUP_ADMIN", true},
},
},
{
sql: "RENAME USER user1 to user1_tmp",
ans: []visitInfo{
{mysql.CreateUserPriv, "", "", "", ErrSpecificAccessDenied, false, "", false},
},
},
}

for _, tt := range tests {
Expand Down
5 changes: 3 additions & 2 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,8 @@ func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) {
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, *ast.BRIEStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.AlterInstanceStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt,
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt:
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt,
*ast.RenameUserStmt:
return b.buildSimple(node.(ast.StmtNode))
case ast.DDLNode:
return b.buildDDL(ctx, x)
Expand Down Expand Up @@ -2268,7 +2269,7 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) {
case *ast.AlterInstanceStmt:
err := ErrSpecificAccessDenied.GenWithStack("SUPER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", err)
case *ast.AlterUserStmt:
case *ast.AlterUserStmt, *ast.RenameUserStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
case *ast.GrantStmt:
Expand Down
48 changes: 48 additions & 0 deletions privilege/privileges/privileges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,54 @@ func (s *testPrivilegeSuite) TestSecurityEnhancedModeStatusVars(c *C) {
}, nil, nil)
}

func (s *testPrivilegeSuite) TestRenameUser(c *C) {
rootSe := newSession(c, s.store, s.dbName)
mustExec(c, rootSe, "DROP USER IF EXISTS 'ru1'@'localhost'")
mustExec(c, rootSe, "DROP USER IF EXISTS ru3")
mustExec(c, rootSe, "DROP USER IF EXISTS ru6@localhost")
mustExec(c, rootSe, "CREATE USER 'ru1'@'localhost'")
mustExec(c, rootSe, "CREATE USER ru3")
mustExec(c, rootSe, "CREATE USER ru6@localhost")
se1 := newSession(c, s.store, s.dbName)
c.Assert(se1.Auth(&auth.UserIdentity{Username: "ru1", Hostname: "localhost"}, nil, nil), IsTrue)

// Check privileges (need CREATE USER)
_, err := se1.ExecuteInternal(context.Background(), "RENAME USER ru3 TO ru4")
c.Assert(err, ErrorMatches, ".*Access denied; you need .at least one of. the CREATE USER privilege.s. for this operation")
mustExec(c, rootSe, "GRANT UPDATE ON mysql.user TO 'ru1'@'localhost'")
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER ru3 TO ru4")
c.Assert(err, ErrorMatches, ".*Access denied; you need .at least one of. the CREATE USER privilege.s. for this operation")
mustExec(c, rootSe, "GRANT CREATE USER ON *.* TO 'ru1'@'localhost'")
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER ru3 TO ru4")
c.Assert(err, IsNil)

// Test a few single rename (both Username and Hostname)
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER 'ru4'@'%' TO 'ru3'@'localhost'")
c.Assert(err, IsNil)
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER 'ru3'@'localhost' TO 'ru3'@'%'")
c.Assert(err, IsNil)
// Including negative tests, i.e. non existing from user and existing to user
_, err = rootSe.ExecuteInternal(context.Background(), "RENAME USER ru3 TO ru1@localhost")
c.Assert(err, ErrorMatches, ".*Operation RENAME USER failed for ru3@%.*")
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER ru4 TO ru5@localhost")
c.Assert(err, ErrorMatches, ".*Operation RENAME USER failed for ru4@%.*")
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER ru3 TO ru3")
c.Assert(err, ErrorMatches, ".*Operation RENAME USER failed for ru3@%.*")
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER ru3 TO ru5@localhost, ru4 TO ru7")
c.Assert(err, ErrorMatches, ".*Operation RENAME USER failed for ru4@%.*")
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER ru3 TO ru5@localhost, ru6@localhost TO ru1@localhost")
c.Assert(err, ErrorMatches, ".*Operation RENAME USER failed for ru6@localhost.*")

// Test multi rename, this is a full swap of ru3 and ru6, i.e. need to read its previous state in the same transaction.
_, err = se1.ExecuteInternal(context.Background(), "RENAME USER 'ru3' TO 'ru3_tmp', ru6@localhost TO ru3, 'ru3_tmp' to ru6@localhost")
c.Assert(err, IsNil)

// Cleanup
mustExec(c, rootSe, "DROP USER ru6@localhost")
mustExec(c, rootSe, "DROP USER ru3")
mustExec(c, rootSe, "DROP USER 'ru1'@'localhost'")
}

func (s *testPrivilegeSuite) TestSecurityEnhancedModeSysVars(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("CREATE USER svroot1, svroot2")
Expand Down
3 changes: 2 additions & 1 deletion session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2862,7 +2862,8 @@ func logStmt(execStmt *executor.ExecStmt, vars *variable.SessionVars) {
switch stmt := execStmt.StmtNode.(type) {
case *ast.CreateUserStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.SetPwdStmt, *ast.GrantStmt,
*ast.RevokeStmt, *ast.AlterTableStmt, *ast.CreateDatabaseStmt, *ast.CreateIndexStmt, *ast.CreateTableStmt,
*ast.DropDatabaseStmt, *ast.DropIndexStmt, *ast.DropTableStmt, *ast.RenameTableStmt, *ast.TruncateTableStmt:
*ast.DropDatabaseStmt, *ast.DropIndexStmt, *ast.DropTableStmt, *ast.RenameTableStmt, *ast.TruncateTableStmt,
*ast.RenameUserStmt:
user := vars.User
schemaVersion := vars.GetInfoSchema().SchemaMetaVersion()
if ss, ok := execStmt.StmtNode.(ast.SensitiveStmtNode); ok {
Expand Down
2 changes: 1 addition & 1 deletion session/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql s
sessVars := se.sessionVars
if meetsErr != nil {
if !sessVars.InTxn() {
logutil.BgLogger().Info("rollbackTxn for ddl/autocommit failed")
logutil.BgLogger().Info("rollbackTxn called due to ddl/autocommit failure")
se.RollbackTxn(ctx)
recordAbortTxnDuration(sessVars)
} else if se.txn.Valid() && se.txn.IsPessimistic() && executor.ErrDeadlock.Equal(meetsErr) {
Expand Down

0 comments on commit 15dfd7b

Please sign in to comment.