From c963c7c0867afa671085a0b61eb47c06889bc1a3 Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Sat, 12 Aug 2023 22:23:00 +0800 Subject: [PATCH 1/8] make ExchangePartition follow check constraints --- ddl/partition.go | 87 +++++++++++++++++++++ ddl/tests/partition/db_partition_test.go | 96 ++++++++++++++++++++++++ 2 files changed, 183 insertions(+) diff --git a/ddl/partition.go b/ddl/partition.go index 43210a2e7a50f..120336760c099 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -46,6 +46,7 @@ import ( "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" @@ -2376,6 +2377,12 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo return ver, errors.Trace(err) } + ptDbInfo, err := t.GetDatabase(ptSchemaID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + nt, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) if err != nil { return ver, errors.Trace(err) @@ -2448,6 +2455,14 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo } } + if variable.EnableCheckConstraint.Load() { + err = verifyExchangePartitionRecordCheckConstraint(w, pt, nt, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } + // partition table auto IDs. ptAutoIDs, err := t.GetAutoIDAccessors(ptSchemaID, ptID).Get() if err != nil { @@ -3313,6 +3328,78 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde return nil } +func verifyExchangePartitionRecordCheckConstraint(w *worker, pt, nt *model.TableInfo, pschemaName, nschemaName, partitionName string) error { + getWriteableConstraintExpr := func(constraints []*model.ConstraintInfo) []string { + writeableConstraintExpr := make([]string, 0, len(constraints)) + for _, con := range constraints { + if !con.Enforced { + continue + } + if con.State == model.StateDeleteOnly || con.State == model.StateDeleteReorganization { + continue + } + writeableConstraintExpr = append(writeableConstraintExpr, con.ExprString) + } + return writeableConstraintExpr + } + + verifyFunc := func(schemaName, tableName, partitionName string, constraintExprs []string) error { + var sql string + paramList := make([]interface{}, 0, 3) + var buf strings.Builder + buf.WriteString("select 1 from %n.%n") + paramList = append(paramList, schemaName, tableName) + if len(partitionName) != 0 { + buf.WriteString(" partition(%n)") + paramList = append(paramList, partitionName) + } + buf.WriteString(" where not (") + for i, con := range constraintExprs { + if i == 0 { + buf.WriteString(con) + } else { + buf.WriteString(fmt.Sprintf(" and %s", con)) + } + } + buf.WriteString(") limit 1") + sql = buf.String() + + logutil.BgLogger().Error(fmt.Sprintf("jiyfsql:%s,args:%v", sql, paramList)) + + var ctx sessionctx.Context + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) + + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ctx, nil, sql, paramList...) + if err != nil { + return errors.Trace(err) + } + rowCount := len(rows) + if rowCount != 0 { + // TODO: return other error code + return errors.Trace(dbterror.ErrRowDoesNotMatchPartition) + } + return nil + } + + pCons := getWriteableConstraintExpr(pt.Constraints) + nCons := getWriteableConstraintExpr(nt.Constraints) + if len(pCons) > 0 { + if err := verifyFunc(nschemaName, nt.Name.L, "", pCons); err != nil { + return errors.Trace(err) + } + } + if len(nCons) > 0 { + if err := verifyFunc(pschemaName, pt.Name.L, partitionName, nCons); err != nil { + return errors.Trace(err) + } + } + return nil +} + func checkExchangePartitionPlacementPolicy(t *meta.Meta, ntPPRef, ptPPRef, partPPRef *model.PolicyRefInfo) error { partitionPPRef := partPPRef if partitionPPRef == nil { diff --git a/ddl/tests/partition/db_partition_test.go b/ddl/tests/partition/db_partition_test.go index abf22c951aaa1..6925c5cd0bf9c 100644 --- a/ddl/tests/partition/db_partition_test.go +++ b/ddl/tests/partition/db_partition_test.go @@ -3367,6 +3367,102 @@ func TestExchangePartitionValidation(t *testing.T) { tk.MustExec(`insert into t1 values ("2023-08-06","0001")`) } +func TestExchangePartitionCheckConstraint(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec(`create database db_one`) + defer tk.MustExec(`drop database db_one`) + tk.MustExec(`create database db_two`) + defer tk.MustExec(`drop database db_two`) + + ntSql := "create table db_one.nt (a int check (a > 75) not ENFORCED, b int check (b > 50) ENFORCED)" + ptSql := "create table db_two.pt (a int check (a < 75) ENFORCED, b int check (b < 75) ENFORCED) partition by range (a) (partition p0 values less than (50), partition p1 values less than (100) )" + alterSql := "alter table db_two.pt exchange partition p1 with table db_one.nt" + dropSql := "drop table db_one.nt, db_two.pt" + errMsg := "[ddl:1737]Found a row that does not match the partition" + + type record struct { + a int + b int + } + inputs := []struct { + t []record + pt []record + ok bool + }{ + { + t: []record{{60, 60}}, + ok: true, + }, + { + t: []record{{80, 60}}, + ok: false, + }, + { + t: []record{{60, 80}}, + ok: false, + }, + { + t: []record{{80, 80}}, + ok: false, + }, + { + pt: []record{{60, 60}}, + ok: true, + }, + { + pt: []record{{60, 50}}, + ok: false, + }, + // Record in partition p0(less than (50)). + { + pt: []record{{30, 50}}, + ok: true, + }, + { + t: []record{{60, 60}}, + pt: []record{{70, 70}}, + ok: true, + }, + { + t: []record{{60, 60}}, + pt: []record{{70, 70}, {30, 50}}, + ok: true, + }, + { + t: []record{{60, 60}}, + pt: []record{{70, 70}, {30, 50}, {60, 50}}, + ok: false, + }, + { + t: []record{{60, 60}, {60, 80}}, + pt: []record{{70, 70}, {30, 50}}, + ok: false, + }, + } + for _, input := range inputs { + tk.MustExec(`set @@global.tidb_enable_check_constraint = 1`) + tk.MustExec(ntSql) + for _, r := range input.t { + tk.MustExec(fmt.Sprintf("insert into db_one.nt values (%d, %d)", r.a, r.b)) + } + tk.MustExec(ptSql) + for _, r := range input.pt { + tk.MustExec(fmt.Sprintf("insert into db_two.pt values (%d, %d)", r.a, r.b)) + } + if input.ok { + tk.MustExec(alterSql) + tk.MustExec(dropSql) + continue + } + tk.MustContainErrMsg(alterSql, errMsg) + tk.MustExec(`set @@global.tidb_enable_check_constraint = 0`) + tk.MustExec(alterSql) + tk.MustExec(dropSql) + } +} + func TestExchangePartitionPlacementPolicy(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) From 9eaee2257f6c4d703c7fc7e3cdce39bcf3ed8081 Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Wed, 16 Aug 2023 21:27:11 +0800 Subject: [PATCH 2/8] fix --- ddl/partition.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ddl/partition.go b/ddl/partition.go index 120336760c099..51b5a6e6fefc6 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -2458,7 +2458,7 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo if variable.EnableCheckConstraint.Load() { err = verifyExchangePartitionRecordCheckConstraint(w, pt, nt, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) if err != nil { - job.State = model.JobStateCancelled + job.State = model.JobStateRollingback return ver, errors.Trace(err) } } @@ -3364,8 +3364,6 @@ func verifyExchangePartitionRecordCheckConstraint(w *worker, pt, nt *model.Table buf.WriteString(") limit 1") sql = buf.String() - logutil.BgLogger().Error(fmt.Sprintf("jiyfsql:%s,args:%v", sql, paramList)) - var ctx sessionctx.Context ctx, err := w.sessPool.Get() if err != nil { From 865aa53473b3e3bdeebd0ffc5c86f92eccf12dbf Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Thu, 17 Aug 2023 08:55:29 +0800 Subject: [PATCH 3/8] fix --- ddl/tests/partition/db_partition_test.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/ddl/tests/partition/db_partition_test.go b/ddl/tests/partition/db_partition_test.go index 6925c5cd0bf9c..99c4b7df238f2 100644 --- a/ddl/tests/partition/db_partition_test.go +++ b/ddl/tests/partition/db_partition_test.go @@ -3376,10 +3376,10 @@ func TestExchangePartitionCheckConstraint(t *testing.T) { tk.MustExec(`create database db_two`) defer tk.MustExec(`drop database db_two`) - ntSql := "create table db_one.nt (a int check (a > 75) not ENFORCED, b int check (b > 50) ENFORCED)" - ptSql := "create table db_two.pt (a int check (a < 75) ENFORCED, b int check (b < 75) ENFORCED) partition by range (a) (partition p0 values less than (50), partition p1 values less than (100) )" - alterSql := "alter table db_two.pt exchange partition p1 with table db_one.nt" - dropSql := "drop table db_one.nt, db_two.pt" + ntSQL := "create table db_one.nt (a int check (a > 75) not ENFORCED, b int check (b > 50) ENFORCED)" + ptSQL := "create table db_two.pt (a int check (a < 75) ENFORCED, b int check (b < 75) ENFORCED) partition by range (a) (partition p0 values less than (50), partition p1 values less than (100) )" + alterSQL := "alter table db_two.pt exchange partition p1 with table db_one.nt" + dropSQL := "drop table db_one.nt, db_two.pt" errMsg := "[ddl:1737]Found a row that does not match the partition" type record struct { @@ -3443,23 +3443,23 @@ func TestExchangePartitionCheckConstraint(t *testing.T) { } for _, input := range inputs { tk.MustExec(`set @@global.tidb_enable_check_constraint = 1`) - tk.MustExec(ntSql) + tk.MustExec(ntSQL) for _, r := range input.t { tk.MustExec(fmt.Sprintf("insert into db_one.nt values (%d, %d)", r.a, r.b)) } - tk.MustExec(ptSql) + tk.MustExec(ptSQL) for _, r := range input.pt { tk.MustExec(fmt.Sprintf("insert into db_two.pt values (%d, %d)", r.a, r.b)) } if input.ok { - tk.MustExec(alterSql) - tk.MustExec(dropSql) + tk.MustExec(alterSQL) + tk.MustExec(dropSQL) continue } - tk.MustContainErrMsg(alterSql, errMsg) + tk.MustContainErrMsg(alterSQL, errMsg) tk.MustExec(`set @@global.tidb_enable_check_constraint = 0`) - tk.MustExec(alterSql) - tk.MustExec(dropSql) + tk.MustExec(alterSQL) + tk.MustExec(dropSQL) } } From 5e7a1b92eb8365be17613554a5f9b5640bd7a65c Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Mon, 21 Aug 2023 13:05:50 +0800 Subject: [PATCH 4/8] fix --- ddl/partition.go | 20 +++++++++----------- ddl/tests/partition/db_partition_test.go | 22 +++++++++++++++++++--- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/ddl/partition.go b/ddl/partition.go index 51b5a6e6fefc6..18e776042a764 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -2453,13 +2453,12 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo job.State = model.JobStateRollingback return ver, errors.Trace(err) } - } - - if variable.EnableCheckConstraint.Load() { - err = verifyExchangePartitionRecordCheckConstraint(w, pt, nt, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) - if err != nil { - job.State = model.JobStateRollingback - return ver, errors.Trace(err) + if variable.EnableCheckConstraint.Load() { + err = verifyExchangePartitionRecordCheckConstraint(w, pt, nt, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) + if err != nil { + job.State = model.JobStateRollingback + return ver, errors.Trace(err) + } } } @@ -3355,11 +3354,10 @@ func verifyExchangePartitionRecordCheckConstraint(w *worker, pt, nt *model.Table } buf.WriteString(" where not (") for i, con := range constraintExprs { - if i == 0 { - buf.WriteString(con) - } else { - buf.WriteString(fmt.Sprintf(" and %s", con)) + if i != 0 { + buf.WriteString(" and ") } + buf.WriteString(fmt.Sprintf("(%s)", con)) } buf.WriteString(") limit 1") sql = buf.String() diff --git a/ddl/tests/partition/db_partition_test.go b/ddl/tests/partition/db_partition_test.go index 99c4b7df238f2..daf99d3e13429 100644 --- a/ddl/tests/partition/db_partition_test.go +++ b/ddl/tests/partition/db_partition_test.go @@ -3372,12 +3372,10 @@ func TestExchangePartitionCheckConstraint(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec(`create database db_one`) - defer tk.MustExec(`drop database db_one`) tk.MustExec(`create database db_two`) - defer tk.MustExec(`drop database db_two`) ntSQL := "create table db_one.nt (a int check (a > 75) not ENFORCED, b int check (b > 50) ENFORCED)" - ptSQL := "create table db_two.pt (a int check (a < 75) ENFORCED, b int check (b < 75) ENFORCED) partition by range (a) (partition p0 values less than (50), partition p1 values less than (100) )" + ptSQL := "create table db_two.pt (a int check (a < 75) ENFORCED, b int check (b < 75 or b > 100) ENFORCED) partition by range (a) (partition p0 values less than (50), partition p1 values less than (100) )" alterSQL := "alter table db_two.pt exchange partition p1 with table db_one.nt" dropSQL := "drop table db_one.nt, db_two.pt" errMsg := "[ddl:1737]Found a row that does not match the partition" @@ -3407,6 +3405,14 @@ func TestExchangePartitionCheckConstraint(t *testing.T) { t: []record{{80, 80}}, ok: false, }, + { + t: []record{{60, 120}}, + ok: true, + }, + { + t: []record{{80, 120}}, + ok: false, + }, { pt: []record{{60, 60}}, ok: true, @@ -3430,6 +3436,16 @@ func TestExchangePartitionCheckConstraint(t *testing.T) { pt: []record{{70, 70}, {30, 50}}, ok: true, }, + { + t: []record{{60, 60}, {60, 120}}, + pt: []record{{70, 70}, {30, 50}}, + ok: true, + }, + { + t: []record{{60, 60}, {80, 120}}, + pt: []record{{70, 70}, {30, 50}}, + ok: false, + }, { t: []record{{60, 60}}, pt: []record{{70, 70}, {30, 50}, {60, 50}}, From 35e06c9f3e63d8fdad63124aa54d9a4a827371f8 Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Mon, 21 Aug 2023 13:55:51 +0800 Subject: [PATCH 5/8] fix --- ddl/partition.go | 51 ++++++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/ddl/partition.go b/ddl/partition.go index 18e776042a764..175042054faf0 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -2454,7 +2454,15 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo return ver, errors.Trace(err) } if variable.EnableCheckConstraint.Load() { - err = verifyExchangePartitionRecordCheckConstraint(w, pt, nt, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) + ntbl, err := getTable(d.store, job.SchemaID, nt) + if err != nil { + return ver, errors.Trace(err) + } + ptbl, err := getTable(d.store, ptSchemaID, pt) + if err != nil { + return ver, errors.Trace(err) + } + err = verifyExchangePartitionRecordCheckConstraint(w, ptbl, ntbl, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) if err != nil { job.State = model.JobStateRollingback return ver, errors.Trace(err) @@ -3327,22 +3335,8 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde return nil } -func verifyExchangePartitionRecordCheckConstraint(w *worker, pt, nt *model.TableInfo, pschemaName, nschemaName, partitionName string) error { - getWriteableConstraintExpr := func(constraints []*model.ConstraintInfo) []string { - writeableConstraintExpr := make([]string, 0, len(constraints)) - for _, con := range constraints { - if !con.Enforced { - continue - } - if con.State == model.StateDeleteOnly || con.State == model.StateDeleteReorganization { - continue - } - writeableConstraintExpr = append(writeableConstraintExpr, con.ExprString) - } - return writeableConstraintExpr - } - - verifyFunc := func(schemaName, tableName, partitionName string, constraintExprs []string) error { +func verifyExchangePartitionRecordCheckConstraint(w *worker, ptbl, ntbl table.Table, pschemaName, nschemaName, partitionName string) error { + verifyFunc := func(schemaName, tableName, partitionName string, constraints []*table.Constraint) error { var sql string paramList := make([]interface{}, 0, 3) var buf strings.Builder @@ -3353,11 +3347,11 @@ func verifyExchangePartitionRecordCheckConstraint(w *worker, pt, nt *model.Table paramList = append(paramList, partitionName) } buf.WriteString(" where not (") - for i, con := range constraintExprs { + for i, cons := range constraints { if i != 0 { buf.WriteString(" and ") } - buf.WriteString(fmt.Sprintf("(%s)", con)) + buf.WriteString(fmt.Sprintf("(%s)", cons.ExprString)) } buf.WriteString(") limit 1") sql = buf.String() @@ -3381,15 +3375,26 @@ func verifyExchangePartitionRecordCheckConstraint(w *worker, pt, nt *model.Table return nil } - pCons := getWriteableConstraintExpr(pt.Constraints) - nCons := getWriteableConstraintExpr(nt.Constraints) + type CheckConstraintTable interface { + WritableConstraint() []*table.Constraint + } + pcc, ok := ptbl.(CheckConstraintTable) + if !ok { + return errors.Errorf("exchange partition process assert table partition failed") + } + ncc, ok := ntbl.(CheckConstraintTable) + if !ok { + return errors.Errorf("exchange partition process assert table partition failed") + } + pCons := pcc.WritableConstraint() + nCons := ncc.WritableConstraint() if len(pCons) > 0 { - if err := verifyFunc(nschemaName, nt.Name.L, "", pCons); err != nil { + if err := verifyFunc(nschemaName, ntbl.Meta().Name.L, "", pCons); err != nil { return errors.Trace(err) } } if len(nCons) > 0 { - if err := verifyFunc(pschemaName, pt.Name.L, partitionName, nCons); err != nil { + if err := verifyFunc(pschemaName, ptbl.Meta().Name.L, partitionName, nCons); err != nil { return errors.Trace(err) } } From f9e53cdda9bcd88b83e44803ba9bf3a8c72a1f30 Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Sat, 16 Sep 2023 14:21:58 +0800 Subject: [PATCH 6/8] fix --- ddl/partition.go | 228 +++++++++++++++++++++++++++++------------------ 1 file changed, 143 insertions(+), 85 deletions(-) diff --git a/ddl/partition.go b/ddl/partition.go index aa4c8530148f1..60ecf29876209 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -2428,7 +2428,7 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo return ver, errors.Trace(err) } - index, partDef, err := getPartitionDef(pt, partName) + _, partDef, err := getPartitionDef(pt, partName) if err != nil { return ver, errors.Trace(err) } @@ -2499,25 +2499,18 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo } if withValidation { - err = checkExchangePartitionRecordValidation(w, pt, index, ntDbInfo.Name, nt.Name) + ntbl, err := getTable(d.store, job.SchemaID, nt) if err != nil { - job.State = model.JobStateRollingback return ver, errors.Trace(err) } - if variable.EnableCheckConstraint.Load() { - ntbl, err := getTable(d.store, job.SchemaID, nt) - if err != nil { - return ver, errors.Trace(err) - } - ptbl, err := getTable(d.store, ptSchemaID, pt) - if err != nil { - return ver, errors.Trace(err) - } - err = verifyExchangePartitionRecordCheckConstraint(w, ptbl, ntbl, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) - if err != nil { - job.State = model.JobStateRollingback - return ver, errors.Trace(err) - } + ptbl, err := getTable(d.store, ptSchemaID, pt) + if err != nil { + return ver, errors.Trace(err) + } + err = checkExchangePartitionRecordValidation(w, ptbl, ntbl, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) + if err != nil { + job.State = model.JobStateRollingback + return ver, errors.Trace(err) } } @@ -3396,61 +3389,136 @@ func bundlesForExchangeTablePartition(t *meta.Meta, pt *model.TableInfo, newPar return bundles, nil } -func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, index int, schemaName, tableName model.CIStr) error { - var sql string - var paramList []interface{} +func checkExchangePartitionRecordValidation(w *worker, ptbl, ntbl table.Table, pschemaName, nschemaName, partitionName string) error { + verifyFunc := func(sql string, params ...interface{}) error { + var ctx sessionctx.Context + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) - pi := pt.Partition + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ctx, nil, sql, params...) + if err != nil { + return errors.Trace(err) + } + rowCount := len(rows) + if rowCount != 0 { + return errors.Trace(dbterror.ErrRowDoesNotMatchPartition) + } + // Check warnings! + // Is it possible to check how many rows where checked as well? + return nil + } + genConstraintCondition := func(constraints []*table.Constraint) string { + var buf strings.Builder + buf.WriteString("not (") + for i, cons := range constraints { + if i != 0 { + buf.WriteString(" and ") + } + buf.WriteString(fmt.Sprintf("(%s)", cons.ExprString)) + } + buf.WriteString(")") + return buf.String() + } + type CheckConstraintTable interface { + WritableConstraint() []*table.Constraint + } + + pt := ptbl.Meta() + index, _, err := getPartitionDef(pt, partitionName) + if err != nil { + return errors.Trace(err) + } + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where ") + paramList := []interface{}{nschemaName, ntbl.Meta().Name.L} + checkNt := true + + pi := pt.Partition switch pi.Type { case model.PartitionTypeHash: if pi.Num == 1 { - return nil + checkNt = false + } else { + buf.WriteString("mod(") + buf.WriteString(pi.Expr) + buf.WriteString(", %?) != %?") + paramList = append(paramList, pi.Num, index) } - var buf strings.Builder - buf.WriteString("select 1 from %n.%n where mod(") - buf.WriteString(pi.Expr) - buf.WriteString(", %?) != %? limit 1") - sql = buf.String() - paramList = append(paramList, schemaName.L, tableName.L, pi.Num, index) case model.PartitionTypeRange: // Table has only one partition and has the maximum value if len(pi.Definitions) == 1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - return nil - } - // For range expression and range columns - if len(pi.Columns) == 0 { - sql, paramList = buildCheckSQLForRangeExprPartition(pi, index, schemaName, tableName) + checkNt = false } else { - sql, paramList = buildCheckSQLForRangeColumnsPartition(pi, index, schemaName, tableName) + // For range expression and range columns + if len(pi.Columns) == 0 { + conds, params := buildCheckSQLConditionForRangeExprPartition(pi, index) + buf.WriteString(conds) + paramList = append(paramList, params...) + } else { + conds, params := buildCheckSQLConditionForRangeColumnsPartition(pi, index) + buf.WriteString(conds) + paramList = append(paramList, params...) + } } case model.PartitionTypeList: if len(pi.Columns) == 0 { - sql, paramList = buildCheckSQLForListPartition(pi, index, schemaName, tableName) + conds := buildCheckSQLConditionForListPartition(pi, index) + buf.WriteString(conds) } else { - sql, paramList = buildCheckSQLForListColumnsPartition(pi, index, schemaName, tableName) + conds := buildCheckSQLConditionForListColumnsPartition(pi, index) + buf.WriteString(conds) } default: return dbterror.ErrUnsupportedPartitionType.GenWithStackByArgs(pt.Name.O) } - var ctx sessionctx.Context - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) + if variable.EnableCheckConstraint.Load() { + pcc, ok := ptbl.(CheckConstraintTable) + if !ok { + return errors.Errorf("exchange partition process assert table partition failed") + } + pCons := pcc.WritableConstraint() + if len(pCons) > 0 { + if !checkNt { + checkNt = true + } else { + buf.WriteString(" or ") - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ctx, nil, sql, paramList...) - if err != nil { - return errors.Trace(err) + } + buf.WriteString(genConstraintCondition(pCons)) + } } - rowCount := len(rows) - if rowCount != 0 { - return errors.Trace(dbterror.ErrRowDoesNotMatchPartition) + // Check non-partition table records. + if checkNt { + buf.WriteString(" limit 1") + err = verifyFunc(buf.String(), paramList...) + if err != nil { + return errors.Trace(err) + } + } + + // Check partition table records. + if variable.EnableCheckConstraint.Load() { + ncc, ok := ntbl.(CheckConstraintTable) + if !ok { + return errors.Errorf("exchange partition process assert table partition failed") + } + nCons := ncc.WritableConstraint() + if len(nCons) > 0 { + buf.Reset() + buf.WriteString("select 1 from %n.%n partition(%n) where ") + buf.WriteString(genConstraintCondition(nCons)) + buf.WriteString(" limit 1") + err = verifyFunc(buf.String(), pschemaName, pt.Name.L, partitionName) + if err != nil { + return errors.Trace(err) + } + } } - // Check warnings! - // Is it possible to check how many rows where checked as well? return nil } @@ -3548,53 +3616,47 @@ func checkExchangePartitionPlacementPolicy(t *meta.Meta, ntPPRef, ptPPRef, partP return nil } -func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { +func buildCheckSQLConditionForRangeExprPartition(pi *model.PartitionInfo, index int) (string, []interface{}) { var buf strings.Builder - paramList := make([]interface{}, 0, 4) + paramList := make([]interface{}, 0, 2) // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...) // So we write it to the origin sql string here. if index == 0 { - buf.WriteString("select 1 from %n.%n where ") buf.WriteString(pi.Expr) - buf.WriteString(" >= %? limit 1") - paramList = append(paramList, schemaName.L, tableName.L, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - return buf.String(), paramList + buf.WriteString(" >= %?") + paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - buf.WriteString("select 1 from %n.%n where ") buf.WriteString(pi.Expr) - buf.WriteString(" < %? limit 1") - paramList = append(paramList, schemaName.L, tableName.L, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) - return buf.String(), paramList + buf.WriteString(" < %?") + paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) } else { - buf.WriteString("select 1 from %n.%n where ") buf.WriteString(pi.Expr) buf.WriteString(" < %? or ") buf.WriteString(pi.Expr) - buf.WriteString(" >= %? limit 1") - paramList = append(paramList, schemaName.L, tableName.L, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - return buf.String(), paramList + buf.WriteString(" >= %?") + paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) } + return buf.String(), paramList } -func buildCheckSQLForRangeColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { - paramList := make([]interface{}, 0, 6) +func buildCheckSQLConditionForRangeColumnsPartition(pi *model.PartitionInfo, index int) (string, []interface{}) { + paramList := make([]interface{}, 0, 2) colName := pi.Columns[0].L if index == 0 { - paramList = append(paramList, schemaName.L, tableName.L, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - return "select 1 from %n.%n where %n >= %? limit 1", paramList + paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) + return "%n >= %?", paramList } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - paramList = append(paramList, schemaName.L, tableName.L, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) - return "select 1 from %n.%n where %n < %? limit 1", paramList + paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) + return "%n < %?", paramList } else { - paramList = append(paramList, schemaName.L, tableName.L, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - return "select 1 from %n.%n where %n < %? or %n >= %? limit 1", paramList + paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) + return "%n < %? or %n >= %?", paramList } } -func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { +func buildCheckSQLConditionForListPartition(pi *model.PartitionInfo, index int) string { var buf strings.Builder - buf.WriteString("select 1 from %n.%n where ") - buf.WriteString(" not (") + buf.WriteString("not (") for i, inValue := range pi.Definitions[index].InValues { if i != 0 { buf.WriteString(" OR ") @@ -3609,19 +3671,17 @@ func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaNam buf.WriteString(fmt.Sprintf("(%s) <=> %s", pi.Expr, val)) } } - buf.WriteString(") limit 1") - paramList := make([]interface{}, 0, 2) - paramList = append(paramList, schemaName.L, tableName.L) - return buf.String(), paramList + buf.WriteString(")") + return buf.String() } -func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { +func buildCheckSQLConditionForListColumnsPartition(pi *model.PartitionInfo, index int) string { var buf strings.Builder // How to find a match? // (row <=> vals1) OR (row <=> vals2) // How to find a non-matching row: // NOT ( (row <=> vals1) OR (row <=> vals2) ... ) - buf.WriteString("select 1 from %n.%n where not (") + buf.WriteString("not (") colNames := make([]string, 0, len(pi.Columns)) for i := range pi.Columns { // TODO: check if there are no proper quoting function for this? @@ -3641,10 +3701,8 @@ func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, sc buf.WriteString(fmt.Sprintf("%s <=> %s", colNames[j], val)) } } - buf.WriteString(") limit 1") - paramList := make([]interface{}, 0, 2) - paramList = append(paramList, schemaName.L, tableName.L) - return buf.String(), paramList + buf.WriteString(")") + return buf.String() } func checkAddPartitionTooManyPartitions(piDefs uint64) error { From d47ff82e3460715e0c50e2f5b9342da587e5e348 Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Sat, 16 Sep 2023 14:24:24 +0800 Subject: [PATCH 7/8] fix --- ddl/partition.go | 66 ------------------------------------------------ 1 file changed, 66 deletions(-) diff --git a/ddl/partition.go b/ddl/partition.go index 60ecf29876209..5550cb2a562ee 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -3522,72 +3522,6 @@ func checkExchangePartitionRecordValidation(w *worker, ptbl, ntbl table.Table, p return nil } -func verifyExchangePartitionRecordCheckConstraint(w *worker, ptbl, ntbl table.Table, pschemaName, nschemaName, partitionName string) error { - verifyFunc := func(schemaName, tableName, partitionName string, constraints []*table.Constraint) error { - var sql string - paramList := make([]interface{}, 0, 3) - var buf strings.Builder - buf.WriteString("select 1 from %n.%n") - paramList = append(paramList, schemaName, tableName) - if len(partitionName) != 0 { - buf.WriteString(" partition(%n)") - paramList = append(paramList, partitionName) - } - buf.WriteString(" where not (") - for i, cons := range constraints { - if i != 0 { - buf.WriteString(" and ") - } - buf.WriteString(fmt.Sprintf("(%s)", cons.ExprString)) - } - buf.WriteString(") limit 1") - sql = buf.String() - - var ctx sessionctx.Context - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) - - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ctx, nil, sql, paramList...) - if err != nil { - return errors.Trace(err) - } - rowCount := len(rows) - if rowCount != 0 { - // TODO: return other error code - return errors.Trace(dbterror.ErrRowDoesNotMatchPartition) - } - return nil - } - - type CheckConstraintTable interface { - WritableConstraint() []*table.Constraint - } - pcc, ok := ptbl.(CheckConstraintTable) - if !ok { - return errors.Errorf("exchange partition process assert table partition failed") - } - ncc, ok := ntbl.(CheckConstraintTable) - if !ok { - return errors.Errorf("exchange partition process assert table partition failed") - } - pCons := pcc.WritableConstraint() - nCons := ncc.WritableConstraint() - if len(pCons) > 0 { - if err := verifyFunc(nschemaName, ntbl.Meta().Name.L, "", pCons); err != nil { - return errors.Trace(err) - } - } - if len(nCons) > 0 { - if err := verifyFunc(pschemaName, ptbl.Meta().Name.L, partitionName, nCons); err != nil { - return errors.Trace(err) - } - } - return nil -} - func checkExchangePartitionPlacementPolicy(t *meta.Meta, ntPPRef, ptPPRef, partPPRef *model.PolicyRefInfo) error { partitionPPRef := partPPRef if partitionPPRef == nil { From c57a3a25ddd0698a8c514e5535e281a8bc8808f4 Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Sat, 16 Sep 2023 14:30:39 +0800 Subject: [PATCH 8/8] fix --- ddl/partition.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ddl/partition.go b/ddl/partition.go index 5550cb2a562ee..1d12664dcfa23 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -3487,7 +3487,6 @@ func checkExchangePartitionRecordValidation(w *worker, ptbl, ntbl table.Table, p checkNt = true } else { buf.WriteString(" or ") - } buf.WriteString(genConstraintCondition(pCons)) }