Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type Cast all update expressions in verify queries #14555

Merged
merged 6 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions go/test/endtoend/vtgate/foreignkey/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,15 @@ func TestFkQueries(t *testing.T) {
"update fk_t11 set col = id where id in (1, 5)",
},
},
{
name: "Update on child to 0 when parent has -0",
queries: []string{
"insert into fk_t15 (id, col) values (2, '-0')",
"insert /*+ SET_VAR(foreign_key_checks=0) */ into fk_t16 (id, col) values (3, '5'), (4, '-5')",
"update fk_t16 set col = col * (col - (col)) where id = 3",
"update fk_t16 set col = col * (col - (col)) where id = 4",
},
},
}

for _, testcase := range testcases {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 4 additions & 19 deletions go/vt/vtgate/engine/fk_cascade.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// FkChild contains the Child Primitive to be executed collecting the values from the Selection Primitive using the column indexes.
Expand All @@ -42,12 +41,10 @@ type FkChild struct {

// NonLiteralUpdateInfo stores the information required to process non-literal update queries.
// It stores 4 information-
// 1. ExprCol- The index of the column being updated in the select query.
// 2. CompExprCol- The index of the comparison expression in the select query to know if the row value is actually being changed or not.
// 3. UpdateExprCol- The index of the updated expression in the select query.
// 4. UpdateExprBvName- The bind variable name to store the updated expression into.
// 1. CompExprCol- The index of the comparison expression in the select query to know if the row value is actually being changed or not.
// 2. UpdateExprCol- The index of the updated expression in the select query.
// 3. UpdateExprBvName- The bind variable name to store the updated expression into.
type NonLiteralUpdateInfo struct {
ExprCol int
CompExprCol int
UpdateExprCol int
UpdateExprBvName string
Expand Down Expand Up @@ -188,19 +185,7 @@ func (fkc *FkCascade) executeNonLiteralExprFkChild(ctx context.Context, vcursor

// Next, we need to copy the updated expressions value into the bind variables map.
for _, info := range child.NonLiteralInfo {
// Type case the value to that of the column that we are updating.
// This is required for example when we receive an updated float value of -0, but
// the column being updated is a varchar column, then if we don't coerce the value of -0 to
// varchar, MySQL ends up setting it to '0' instead of '-0'.
finalVal := row[info.UpdateExprCol]
if !finalVal.IsNull() {
var err error
finalVal, err = evalengine.CoerceTo(finalVal, selectionRes.Fields[info.ExprCol].Type)
if err != nil {
return err
}
}
bindVars[info.UpdateExprBvName] = sqltypes.ValueBindVariable(finalVal)
bindVars[info.UpdateExprBvName] = sqltypes.ValueBindVariable(row[info.UpdateExprCol])
}
_, err := vcursor.ExecutePrimitive(ctx, child.Exec, bindVars, wantfields)
if err != nil {
Expand Down
84 changes: 62 additions & 22 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"slices"
"strings"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/sysvars"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -217,7 +219,7 @@ func buildFkOperator(ctx *plancontext.PlanningContext, updOp ops.Operator, updCl
return nil, err
}

return createFKVerifyOp(ctx, op, updClone, parentFks, restrictChildFks)
return createFKVerifyOp(ctx, op, updClone, parentFks, restrictChildFks, updatedTable)
}

// splitChildFks splits the child foreign keys into restrict and cascade list as restrict is handled through Verify operator and cascade is handled through Cascade operator.
Expand Down Expand Up @@ -268,7 +270,7 @@ func createFKCascadeOp(ctx *plancontext.PlanningContext, parentOp ops.Operator,
for _, updExpr := range ue {
// We add the expression and a comparison expression to the SELECT exprssion while storing their offsets.
var info engine.NonLiteralUpdateInfo
info, selectExprs = addNonLiteralUpdExprToSelect(ctx, updExpr, selectExprs)
info, selectExprs = addNonLiteralUpdExprToSelect(ctx, updatedTable, updExpr, selectExprs)
nonLiteralUpdateInfo = append(nonLiteralUpdateInfo, info)
}
}
Expand Down Expand Up @@ -329,25 +331,22 @@ func addColumns(ctx *plancontext.PlanningContext, columns sqlparser.Columns, exp
// For an update query having non-literal updates, we add the updated expression and a comparison expression to the select query.
// For example, for a query like `update fk_table set col = id * 100 + 1`
// We would add the expression `id * 100 + 1` and the comparison expression `col <=> id * 100 + 1` to the select query.
func addNonLiteralUpdExprToSelect(ctx *plancontext.PlanningContext, updExpr *sqlparser.UpdateExpr, exprs []sqlparser.SelectExpr) (engine.NonLiteralUpdateInfo, []sqlparser.SelectExpr) {
func addNonLiteralUpdExprToSelect(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updExpr *sqlparser.UpdateExpr, exprs []sqlparser.SelectExpr) (engine.NonLiteralUpdateInfo, []sqlparser.SelectExpr) {
// Create the comparison expression.
compExpr := sqlparser.NewComparisonExpr(sqlparser.NullSafeEqualOp, updExpr.Name, updExpr.Expr, nil)
castedExpr := getCastedUpdateExpression(updatedTable, updExpr)
compExpr := sqlparser.NewComparisonExpr(sqlparser.NullSafeEqualOp, updExpr.Name, castedExpr, nil)
info := engine.NonLiteralUpdateInfo{
CompExprCol: -1,
UpdateExprCol: -1,
ExprCol: -1,
}
// Add the expressions to the select expressions. We make sure to reuse the offset if it has already been added once.
for idx, selectExpr := range exprs {
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, compExpr) {
info.CompExprCol = idx
}
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, updExpr.Expr) {
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, castedExpr) {
info.UpdateExprCol = idx
}
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, updExpr.Name) {
info.ExprCol = idx
}
}
// If the expression doesn't exist, then we add the expression and store the offset.
if info.CompExprCol == -1 {
Expand All @@ -356,15 +355,54 @@ func addNonLiteralUpdExprToSelect(ctx *plancontext.PlanningContext, updExpr *sql
}
if info.UpdateExprCol == -1 {
info.UpdateExprCol = len(exprs)
exprs = append(exprs, aeWrap(updExpr.Expr))
}
if info.ExprCol == -1 {
info.ExprCol = len(exprs)
exprs = append(exprs, aeWrap(updExpr.Name))
exprs = append(exprs, aeWrap(castedExpr))
}
return info, exprs
}

func getCastedUpdateExpression(updatedTable *vindexes.Table, updExpr *sqlparser.UpdateExpr) sqlparser.Expr {
castTypeStr := getCastTypeForColumn(updatedTable, updExpr)
if castTypeStr == "" {
return updExpr.Expr
}
return &sqlparser.CastExpr{
Expr: updExpr.Expr,
Type: &sqlparser.ConvertType{
Type: castTypeStr,
},
}
}

func getCastTypeForColumn(updatedTable *vindexes.Table, updExpr *sqlparser.UpdateExpr) string {
var ty querypb.Type
for _, column := range updatedTable.Columns {
if updExpr.Name.Name.Equal(column.Name) {
ty = column.Type
break
}
}
switch {
case sqltypes.IsNull(ty):
return ""
case sqltypes.IsSigned(ty):
return "SIGNED"
case sqltypes.IsUnsigned(ty):
return "UNSIGNED"
case sqltypes.IsFloat(ty):
return "FLOAT"
case sqltypes.IsDecimal(ty):
return "DECIMAL"
case sqltypes.IsDateOrTime(ty):
return "DATETIME"
case sqltypes.IsBinary(ty):
return "BINARY"
case sqltypes.IsText(ty):
return "CHAR"
default:
return ""
}
}

// createFkChildForUpdate creates the update query operator for the child table based on the foreign key constraints.
func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, selectOffsets []int, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, updatedTable *vindexes.Table) (*FkChild, error) {
// Create a ValTuple of child column names
Expand Down Expand Up @@ -484,6 +522,7 @@ func buildChildUpdOpForSetNull(
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL).
updateExprs := ctx.SemTable.GetUpdateExpressionsForFk(fk.String(updatedTable))
compExpr := nullSafeNotInComparison(ctx,
updatedTable,
updateExprs, fk, updatedTable.GetTableName(), nonLiteralUpdateInfo, false /* appendQualifier */)
if compExpr != nil {
childWhereExpr = &sqlparser.AndExpr{
Expand Down Expand Up @@ -522,6 +561,7 @@ func createFKVerifyOp(
updStmt *sqlparser.Update,
parentFks []vindexes.ParentFKInfo,
restrictChildFks []vindexes.ChildFKInfo,
updatedTable *vindexes.Table,
) (ops.Operator, error) {
if len(parentFks) == 0 && len(restrictChildFks) == 0 {
return childOp, nil
Expand All @@ -530,7 +570,7 @@ func createFKVerifyOp(
var Verify []*VerifyOp
// This validates that new values exists on the parent table.
for _, fk := range parentFks {
op, err := createFkVerifyOpForParentFKForUpdate(ctx, updStmt, fk)
op, err := createFkVerifyOpForParentFKForUpdate(ctx, updatedTable, updStmt, fk)
if err != nil {
return nil, err
}
Expand All @@ -541,7 +581,7 @@ func createFKVerifyOp(
}
// This validates that the old values don't exist on the child table.
for _, fk := range restrictChildFks {
op, err := createFkVerifyOpForChildFKForUpdate(ctx, updStmt, fk)
op, err := createFkVerifyOpForChildFKForUpdate(ctx, updatedTable, updStmt, fk)
if err != nil {
return nil, err
}
Expand All @@ -568,7 +608,7 @@ func createFKVerifyOp(
// where Parent.p1 is null and Parent.p2 is null and Child.id = 1 and Child.c2 + 1 is not null
// and Child.c2 is not null and not ((Child.c1) <=> (Child.c2 + 1))
// limit 1
func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, pFK vindexes.ParentFKInfo) (ops.Operator, error) {
func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updStmt *sqlparser.Update, pFK vindexes.ParentFKInfo) (ops.Operator, error) {
childTblExpr := updStmt.TableExprs[0].(*sqlparser.AliasedTableExpr)
childTbl, err := childTblExpr.TableName()
if err != nil {
Expand Down Expand Up @@ -608,7 +648,7 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS
}
} else {
notEqualColNames = append(notEqualColNames, prefixColNames(ctx, childTbl, matchedExpr.Name))
prefixedMatchExpr := prefixColNames(ctx, childTbl, matchedExpr.Expr)
prefixedMatchExpr := prefixColNames(ctx, childTbl, getCastedUpdateExpression(updatedTable, matchedExpr))
notEqualExprs = append(notEqualExprs, prefixedMatchExpr)
joinExpr = &sqlparser.ComparisonExpr{
Operator: sqlparser.EqualOp,
Expand Down Expand Up @@ -668,7 +708,7 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS
// verify query:
// select 1 from Child join Parent on Parent.p1 = Child.c1 and Parent.p2 = Child.c2
// where Parent.id = 1 and ((Parent.col + 1) IS NULL OR (child.c1) NOT IN ((Parent.col + 1))) limit 1
func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, cFk vindexes.ChildFKInfo) (ops.Operator, error) {
func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updStmt *sqlparser.Update, cFk vindexes.ChildFKInfo) (ops.Operator, error) {
// ON UPDATE RESTRICT foreign keys that require validation, should only be allowed in the case where we
// are verifying all the FKs on vtgate level.
if !ctx.VerifyAllFKs {
Expand Down Expand Up @@ -710,7 +750,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updSt
// For example, if we are setting `update child cola = :v1 and colb = :v2`, then on the parent, the where condition would look something like this -
// `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))`
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL).
compExpr := nullSafeNotInComparison(ctx, updStmt.Exprs, cFk, parentTbl, nil /* nonLiteralUpdateInfo */, true /* appendQualifier */)
compExpr := nullSafeNotInComparison(ctx, updatedTable, updStmt.Exprs, cFk, parentTbl, nil /* nonLiteralUpdateInfo */, true /* appendQualifier */)
if compExpr != nil {
whereCond = sqlparser.AndExpressions(whereCond, compExpr)
}
Expand All @@ -735,7 +775,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updSt
// `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))`
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL)
// This expression is used in cascading SET NULLs and in verifying whether an update should be restricted.
func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo, parentTbl sqlparser.TableName, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, appendQualifier bool) sqlparser.Expr {
func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo, parentTbl sqlparser.TableName, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, appendQualifier bool) sqlparser.Expr {
var valTuple sqlparser.ValTuple
var updateValues sqlparser.ValTuple
for idx, updateExpr := range updateExprs {
Expand All @@ -744,7 +784,7 @@ func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updateExprs sqlpa
if sqlparser.IsNull(updateExpr.Expr) {
return nil
}
childUpdateExpr := prefixColNames(ctx, parentTbl, updateExpr.Expr)
childUpdateExpr := prefixColNames(ctx, parentTbl, getCastedUpdateExpression(updatedTable, updateExpr))
if len(nonLiteralUpdateInfo) > 0 && nonLiteralUpdateInfo[idx].UpdateExprBvName != "" {
childUpdateExpr = sqlparser.NewArgument(nonLiteralUpdateInfo[idx].UpdateExprBvName)
}
Expand Down
Loading
Loading