Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

partition: make ExchangePartition follow check constraints(part1) #46021

Merged
merged 9 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 149 additions & 70 deletions ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/tablecodec"
Expand Down Expand Up @@ -2405,6 +2406,12 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo
return ver, errors.Trace(err)
}

ptDbInfo, err := t.GetDatabase(ptSchemaID)
if err != nil {
job.State = model.JobStateCancelled
return ver, errors.Trace(err)
}

nt, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID)
if err != nil {
return ver, errors.Trace(err)
Expand All @@ -2421,7 +2428,7 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo
return ver, errors.Trace(err)
}

index, partDef, err := getPartitionDef(pt, partName)
_, partDef, err := getPartitionDef(pt, partName)
if err != nil {
return ver, errors.Trace(err)
}
Expand Down Expand Up @@ -2492,7 +2499,15 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo
}

if withValidation {
err = checkExchangePartitionRecordValidation(w, pt, index, ntDbInfo.Name, nt.Name)
ntbl, err := getTable(d.store, job.SchemaID, nt)
if err != nil {
return ver, errors.Trace(err)
}
ptbl, err := getTable(d.store, ptSchemaID, pt)
if err != nil {
return ver, errors.Trace(err)
}
err = checkExchangePartitionRecordValidation(w, ptbl, ntbl, ptDbInfo.Name.L, ntDbInfo.Name.L, partName)
if err != nil {
job.State = model.JobStateRollingback
return ver, errors.Trace(err)
Expand Down Expand Up @@ -3374,61 +3389,135 @@ func bundlesForExchangeTablePartition(t *meta.Meta, pt *model.TableInfo, newPar
return bundles, nil
}

func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, index int, schemaName, tableName model.CIStr) error {
var sql string
var paramList []interface{}
func checkExchangePartitionRecordValidation(w *worker, ptbl, ntbl table.Table, pschemaName, nschemaName, partitionName string) error {
verifyFunc := func(sql string, params ...interface{}) error {
var ctx sessionctx.Context
ctx, err := w.sessPool.Get()
if err != nil {
return errors.Trace(err)
}
defer w.sessPool.Put(ctx)

pi := pt.Partition
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ctx, nil, sql, params...)
if err != nil {
return errors.Trace(err)
}
rowCount := len(rows)
if rowCount != 0 {
return errors.Trace(dbterror.ErrRowDoesNotMatchPartition)
}
// Check warnings!
// Is it possible to check how many rows where checked as well?
return nil
}
genConstraintCondition := func(constraints []*table.Constraint) string {
var buf strings.Builder
buf.WriteString("not (")
for i, cons := range constraints {
if i != 0 {
buf.WriteString(" and ")
}
buf.WriteString(fmt.Sprintf("(%s)", cons.ExprString))
}
buf.WriteString(")")
return buf.String()
}
type CheckConstraintTable interface {
WritableConstraint() []*table.Constraint
}

pt := ptbl.Meta()
index, _, err := getPartitionDef(pt, partitionName)
if err != nil {
return errors.Trace(err)
}

var buf strings.Builder
buf.WriteString("select 1 from %n.%n where ")
paramList := []interface{}{nschemaName, ntbl.Meta().Name.L}
checkNt := true

pi := pt.Partition
switch pi.Type {
case model.PartitionTypeHash:
if pi.Num == 1 {
return nil
checkNt = false
} else {
buf.WriteString("mod(")
buf.WriteString(pi.Expr)
buf.WriteString(", %?) != %?")
Comment on lines +3446 to +3448
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any tests that validates NULLs for Hash partitions? (not blocking this PR, just curious)

paramList = append(paramList, pi.Num, index)
}
var buf strings.Builder
buf.WriteString("select 1 from %n.%n where mod(")
buf.WriteString(pi.Expr)
buf.WriteString(", %?) != %? limit 1")
sql = buf.String()
paramList = append(paramList, schemaName.L, tableName.L, pi.Num, index)
case model.PartitionTypeRange:
// Table has only one partition and has the maximum value
if len(pi.Definitions) == 1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) {
return nil
}
// For range expression and range columns
if len(pi.Columns) == 0 {
sql, paramList = buildCheckSQLForRangeExprPartition(pi, index, schemaName, tableName)
checkNt = false
} else {
sql, paramList = buildCheckSQLForRangeColumnsPartition(pi, index, schemaName, tableName)
// For range expression and range columns
if len(pi.Columns) == 0 {
conds, params := buildCheckSQLConditionForRangeExprPartition(pi, index)
buf.WriteString(conds)
paramList = append(paramList, params...)
} else {
conds, params := buildCheckSQLConditionForRangeColumnsPartition(pi, index)
buf.WriteString(conds)
paramList = append(paramList, params...)
}
}
case model.PartitionTypeList:
if len(pi.Columns) == 0 {
sql, paramList = buildCheckSQLForListPartition(pi, index, schemaName, tableName)
conds := buildCheckSQLConditionForListPartition(pi, index)
buf.WriteString(conds)
} else {
sql, paramList = buildCheckSQLForListColumnsPartition(pi, index, schemaName, tableName)
conds := buildCheckSQLConditionForListColumnsPartition(pi, index)
buf.WriteString(conds)
}
default:
return dbterror.ErrUnsupportedPartitionType.GenWithStackByArgs(pt.Name.O)
}

var ctx sessionctx.Context
ctx, err := w.sessPool.Get()
if err != nil {
return errors.Trace(err)
if variable.EnableCheckConstraint.Load() {
pcc, ok := ptbl.(CheckConstraintTable)
if !ok {
return errors.Errorf("exchange partition process assert table partition failed")
}
pCons := pcc.WritableConstraint()
if len(pCons) > 0 {
if !checkNt {
checkNt = true
} else {
buf.WriteString(" or ")
}
buf.WriteString(genConstraintCondition(pCons))
}
}
defer w.sessPool.Put(ctx)

rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ctx, nil, sql, paramList...)
if err != nil {
return errors.Trace(err)
// Check non-partition table records.
if checkNt {
buf.WriteString(" limit 1")
err = verifyFunc(buf.String(), paramList...)
if err != nil {
return errors.Trace(err)
}
}
rowCount := len(rows)
if rowCount != 0 {
return errors.Trace(dbterror.ErrRowDoesNotMatchPartition)

// Check partition table records.
if variable.EnableCheckConstraint.Load() {
ncc, ok := ntbl.(CheckConstraintTable)
if !ok {
return errors.Errorf("exchange partition process assert table partition failed")
}
nCons := ncc.WritableConstraint()
if len(nCons) > 0 {
buf.Reset()
buf.WriteString("select 1 from %n.%n partition(%n) where ")
buf.WriteString(genConstraintCondition(nCons))
buf.WriteString(" limit 1")
err = verifyFunc(buf.String(), pschemaName, pt.Name.L, partitionName)
if err != nil {
return errors.Trace(err)
}
}
Comment on lines +3504 to +3519
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More of a note, a future improvement may be to run the validations in parallel.

}
// Check warnings!
// Is it possible to check how many rows where checked as well?
return nil
}

Expand Down Expand Up @@ -3460,53 +3549,47 @@ func checkExchangePartitionPlacementPolicy(t *meta.Meta, ntPPRef, ptPPRef, partP
return nil
}

func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) {
func buildCheckSQLConditionForRangeExprPartition(pi *model.PartitionInfo, index int) (string, []interface{}) {
var buf strings.Builder
paramList := make([]interface{}, 0, 4)
paramList := make([]interface{}, 0, 2)
// Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...)
// So we write it to the origin sql string here.
if index == 0 {
buf.WriteString("select 1 from %n.%n where ")
buf.WriteString(pi.Expr)
buf.WriteString(" >= %? limit 1")
paramList = append(paramList, schemaName.L, tableName.L, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
return buf.String(), paramList
buf.WriteString(" >= %?")
paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
} else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) {
buf.WriteString("select 1 from %n.%n where ")
buf.WriteString(pi.Expr)
buf.WriteString(" < %? limit 1")
paramList = append(paramList, schemaName.L, tableName.L, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]))
return buf.String(), paramList
buf.WriteString(" < %?")
paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]))
} else {
buf.WriteString("select 1 from %n.%n where ")
buf.WriteString(pi.Expr)
buf.WriteString(" < %? or ")
buf.WriteString(pi.Expr)
buf.WriteString(" >= %? limit 1")
paramList = append(paramList, schemaName.L, tableName.L, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
return buf.String(), paramList
buf.WriteString(" >= %?")
paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
}
return buf.String(), paramList
}

func buildCheckSQLForRangeColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) {
paramList := make([]interface{}, 0, 6)
func buildCheckSQLConditionForRangeColumnsPartition(pi *model.PartitionInfo, index int) (string, []interface{}) {
paramList := make([]interface{}, 0, 2)
colName := pi.Columns[0].L
if index == 0 {
paramList = append(paramList, schemaName.L, tableName.L, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
return "select 1 from %n.%n where %n >= %? limit 1", paramList
paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
return "%n >= %?", paramList
} else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) {
paramList = append(paramList, schemaName.L, tableName.L, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]))
return "select 1 from %n.%n where %n < %? limit 1", paramList
paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]))
return "%n < %?", paramList
} else {
paramList = append(paramList, schemaName.L, tableName.L, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
return "select 1 from %n.%n where %n < %? or %n >= %? limit 1", paramList
paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0]))
return "%n < %? or %n >= %?", paramList
}
}

func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) {
func buildCheckSQLConditionForListPartition(pi *model.PartitionInfo, index int) string {
var buf strings.Builder
buf.WriteString("select 1 from %n.%n where ")
buf.WriteString(" not (")
buf.WriteString("not (")
for i, inValue := range pi.Definitions[index].InValues {
if i != 0 {
buf.WriteString(" OR ")
Expand All @@ -3521,19 +3604,17 @@ func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaNam
buf.WriteString(fmt.Sprintf("(%s) <=> %s", pi.Expr, val))
}
}
buf.WriteString(") limit 1")
paramList := make([]interface{}, 0, 2)
paramList = append(paramList, schemaName.L, tableName.L)
return buf.String(), paramList
buf.WriteString(")")
return buf.String()
}

func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) {
func buildCheckSQLConditionForListColumnsPartition(pi *model.PartitionInfo, index int) string {
var buf strings.Builder
// How to find a match?
// (row <=> vals1) OR (row <=> vals2)
// How to find a non-matching row:
// NOT ( (row <=> vals1) OR (row <=> vals2) ... )
buf.WriteString("select 1 from %n.%n where not (")
buf.WriteString("not (")
colNames := make([]string, 0, len(pi.Columns))
for i := range pi.Columns {
// TODO: check if there are no proper quoting function for this?
Expand All @@ -3553,10 +3634,8 @@ func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, sc
buf.WriteString(fmt.Sprintf("%s <=> %s", colNames[j], val))
}
}
buf.WriteString(") limit 1")
paramList := make([]interface{}, 0, 2)
paramList = append(paramList, schemaName.L, tableName.L)
return buf.String(), paramList
buf.WriteString(")")
return buf.String()
}

func checkAddPartitionTooManyPartitions(piDefs uint64) error {
Expand Down
Loading