diff --git a/session/nontransactional.go b/session/nontransactional.go index cde9941438c74..303513b9bb8e0 100644 --- a/session/nontransactional.go +++ b/session/nontransactional.go @@ -173,13 +173,15 @@ func splitDeleteWorker(ctx context.Context, jobs []job, stmt *ast.NonTransaction for _, col := range tableName.TableInfo.Columns { if col.Name.L == stmt.ShardColumn.Name.L { shardColumnRefer = &ast.ResultField{ - Column: col, - Table: tableName.TableInfo, + Column: col, + Table: tableName.TableInfo, + DBName: tableName.Schema, + TableName: tableName, } shardColumnType = col.FieldType } } - if shardColumnRefer == nil && stmt.ShardColumn.Name.O != "_tidb_rowid" { + if shardColumnRefer == nil && stmt.ShardColumn.Name.L != model.ExtraHandleName.L { return nil, errors.New("Non-transactional delete, column not found") } @@ -204,6 +206,16 @@ func splitDeleteWorker(ctx context.Context, jobs []job, stmt *ast.NonTransaction default: } + // _tidb_rowid + if shardColumnRefer == nil { + shardColumnType = *types.NewFieldType(mysql.TypeLonglong) + shardColumnRefer = &ast.ResultField{ + Column: model.NewExtraHandleColInfo(), + Table: tableName.TableInfo, + DBName: tableName.Schema, + TableName: tableName, + } + } stmtBuildInfo := statementBuildInfo{ stmt: stmt, shardColumnType: shardColumnType, @@ -495,10 +507,10 @@ func selectShardColumn(stmt *ast.NonTransactionalDeleteStmt, se Session, tableNa if index.Primary { if len(index.Columns) == 1 { shardColumnInfo = tableInfo.Columns[index.Columns[0].Offset] - } else { - // if the clustered index contains multiple columns, we cannot automatically choose a column as the shard column - return false, nil, errors.New("Non-transactional delete, the clustered index contains multiple columns. Please specify a shard column") + break } + // if the clustered index contains multiple columns, we cannot automatically choose a column as the shard column + return false, nil, errors.New("Non-transactional delete, the clustered index contains multiple columns. Please specify a shard column") } } if shardColumnInfo == nil { @@ -506,20 +518,25 @@ func selectShardColumn(stmt *ast.NonTransactionalDeleteStmt, se Session, tableNa } } - shardColumnName := "_tidb_rowid" + shardColumnName := model.ExtraHandleName.L if shardColumnInfo != nil { shardColumnName = shardColumnInfo.Name.L } + + outputTableName := tableName.Name + if tableAsName.L != "" { + outputTableName = tableAsName + } stmt.ShardColumn = &ast.ColumnName{ Schema: tableName.Schema, - Table: tableAsName, // so that table alias works + Table: outputTableName, // so that table alias works Name: model.NewCIStr(shardColumnName), } return true, shardColumnInfo, nil } shardColumnName = stmt.ShardColumn.Name.L - if shardColumnName == "_tidb_rowid" && !tableInfo.HasClusteredIndex() { + if shardColumnName == model.ExtraHandleName.L && !tableInfo.HasClusteredIndex() { return true, nil, nil } diff --git a/session/nontransactional_test.go b/session/nontransactional_test.go index 7c84ed5c69933..bf5129907c921 100644 --- a/session/nontransactional_test.go +++ b/session/nontransactional_test.go @@ -132,6 +132,20 @@ func TestNonTransactionalDeleteSplitOnTiDBRowID(t *testing.T) { for i := 0; i < 100; i++ { tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2)) } + // auto select results in full col name + tk.MustQuery("batch limit 3 dry run delete from t").Check(testkit.Rows( + "DELETE FROM `test`.`t` WHERE `test`.`t`.`_tidb_rowid` BETWEEN 1 AND 3", + "DELETE FROM `test`.`t` WHERE `test`.`t`.`_tidb_rowid` BETWEEN 100 AND 100", + )) + // otherwise the name is the same as what is given + tk.MustQuery("batch on _tidb_rowid limit 3 dry run delete from t").Check(testkit.Rows( + "DELETE FROM `test`.`t` WHERE `_tidb_rowid` BETWEEN 1 AND 3", + "DELETE FROM `test`.`t` WHERE `_tidb_rowid` BETWEEN 100 AND 100", + )) + tk.MustQuery("batch on t._tidb_rowid limit 3 dry run delete from t").Check(testkit.Rows( + "DELETE FROM `test`.`t` WHERE `t`.`_tidb_rowid` BETWEEN 1 AND 3", + "DELETE FROM `test`.`t` WHERE `t`.`_tidb_rowid` BETWEEN 100 AND 100", + )) tk.MustExec("batch on _tidb_rowid limit 3 delete from t") tk.MustQuery("select count(*) from t").Check(testkit.Rows("0")) }