From e3a9c6dac8e497fd00499bf80d89a0d6e1be9931 Mon Sep 17 00:00:00 2001
From: exialin <exialin37@gmail.com>
Date: Mon, 14 Jan 2019 09:47:52 +0800
Subject: [PATCH] *: fix the lower bound when converting numbers less than 0 to
 unsigned integers (#9028)

---
 executor/distsql.go           |  2 +-
 executor/executor.go          |  7 +++---
 executor/executor_test.go     |  1 +
 executor/load_data.go         |  2 +-
 executor/write_test.go        | 41 +++++++++++++++++++++++++++++++++++
 expression/builtin_cast.go    | 12 ++++++----
 expression/errors.go          |  4 ++--
 sessionctx/stmtctx/stmtctx.go | 22 ++++++++++++++++++-
 types/convert.go              | 15 +++++++++----
 types/datum.go                | 26 +++++++++++-----------
 10 files changed, 103 insertions(+), 29 deletions(-)

diff --git a/executor/distsql.go b/executor/distsql.go
index 9b8c5600d3f8e..22e326e2390b0 100644
--- a/executor/distsql.go
+++ b/executor/distsql.go
@@ -122,7 +122,7 @@ func statementContextToFlags(sc *stmtctx.StatementContext) uint64 {
 	var flags uint64
 	if sc.InInsertStmt {
 		flags |= model.FlagInInsertStmt
-	} else if sc.InUpdateOrDeleteStmt {
+	} else if sc.InUpdateStmt || sc.InDeleteStmt {
 		flags |= model.FlagInUpdateOrDeleteStmt
 	} else if sc.InSelectStmt {
 		flags |= model.FlagInSelectStmt
diff --git a/executor/executor.go b/executor/executor.go
index 8d47a4bab4968..91a58b8ea6ab7 100644
--- a/executor/executor.go
+++ b/executor/executor.go
@@ -1245,7 +1245,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
 	// pushing them down to TiKV as flags.
 	switch stmt := s.(type) {
 	case *ast.UpdateStmt:
-		sc.InUpdateOrDeleteStmt = true
+		sc.InUpdateStmt = true
 		sc.DupKeyAsWarning = stmt.IgnoreErr
 		sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
 		sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
@@ -1253,7 +1253,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
 		sc.IgnoreZeroInDate = !vars.StrictSQLMode || stmt.IgnoreErr
 		sc.Priority = stmt.Priority
 	case *ast.DeleteStmt:
-		sc.InUpdateOrDeleteStmt = true
+		sc.InDeleteStmt = true
 		sc.DupKeyAsWarning = stmt.IgnoreErr
 		sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
 		sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
@@ -1274,6 +1274,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
 		sc.DupKeyAsWarning = true
 		sc.BadNullAsWarning = true
 		sc.TruncateAsWarning = !vars.StrictSQLMode
+		sc.InLoadDataStmt = true
 	case *ast.SelectStmt:
 		sc.InSelectStmt = true
 
@@ -1314,7 +1315,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
 		sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID
 	}
 	sc.PrevAffectedRows = 0
-	if vars.StmtCtx.InUpdateOrDeleteStmt || vars.StmtCtx.InInsertStmt {
+	if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt {
 		sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows())
 	} else if vars.StmtCtx.InSelectStmt {
 		sc.PrevAffectedRows = -1
diff --git a/executor/executor_test.go b/executor/executor_test.go
index 5c6a53005713b..cd07756f560e3 100644
--- a/executor/executor_test.go
+++ b/executor/executor_test.go
@@ -329,6 +329,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo,
 		c.Assert(ctx.NewTxn(), IsNil)
 		ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true
 		ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true
+		ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true
 		data, reachLimit, err1 := ld.InsertData(tt.data1, tt.data2)
 		c.Assert(err1, IsNil)
 		c.Assert(reachLimit, IsFalse)
diff --git a/executor/load_data.go b/executor/load_data.go
index 258b5cfbd233d..74a5cff129fde 100644
--- a/executor/load_data.go
+++ b/executor/load_data.go
@@ -201,7 +201,7 @@ func (e *LoadDataInfo) getLine(prevData, curData []byte) ([]byte, []byte, bool)
 }
 
 // InsertData inserts data into specified table according to the specified format.
-// If it has the rest of data isn't completed the processing, then is returns without completed data.
+// If it has the rest of data isn't completed the processing, then it returns without completed data.
 // If the number of inserted rows reaches the batchRows, then the second return value is true.
 // If prevData isn't nil and curData is nil, there are no other data to deal with and the isEOF is true.
 func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error) {
diff --git a/executor/write_test.go b/executor/write_test.go
index 8ea7a5a572128..045c61158d9e7 100644
--- a/executor/write_test.go
+++ b/executor/write_test.go
@@ -227,6 +227,27 @@ func (s *testSuite) TestInsert(c *C) {
 	tk.MustExec("insert into test values(2, 3)")
 	tk.MustQuery("select * from test use index (id) where id = 2").Check(testkit.Rows("2 2", "2 3"))
 
+	// issue 6360
+	tk.MustExec("drop table if exists t;")
+	tk.MustExec("create table t(a bigint unsigned);")
+	tk.MustExec(" set @orig_sql_mode = @@sql_mode; set @@sql_mode = 'strict_all_tables';")
+	_, err = tk.Exec("insert into t value (-1);")
+	c.Assert(types.ErrWarnDataOutOfRange.Equal(err), IsTrue)
+	tk.MustExec("set @@sql_mode = '';")
+	tk.MustExec("insert into t value (-1);")
+	// TODO: the following warning messages are not consistent with MySQL, fix them in the future PRs
+	tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint"))
+	tk.MustExec("insert into t select -1;")
+	tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint"))
+	tk.MustExec("insert into t select cast(-1 as unsigned);")
+	tk.MustExec("insert into t value (-1.111);")
+	tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint"))
+	tk.MustExec("insert into t value ('-1.111');")
+	tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 BIGINT UNSIGNED value is out of range in '-1'"))
+	r = tk.MustQuery("select * from t;")
+	r.Check(testkit.Rows("0", "0", "18446744073709551615", "0", "0"))
+	tk.MustExec("set @@sql_mode = @orig_sql_mode;")
+
 	// issue 6424
 	tk.MustExec("drop table if exists t")
 	tk.MustExec("create table t(a time(6))")
@@ -1699,6 +1720,26 @@ func (s *testSuite) TestLoadDataIgnoreLines(c *C) {
 	checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL)
 }
 
+// related to issue 6360
+func (s *testSuite) TestLoadDataOverflowBigintUnsigned(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test; drop table if exists load_data_test;")
+	tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);")
+	tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test")
+	ctx := tk.Se.(sessionctx.Context)
+	ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo)
+	c.Assert(ok, IsTrue)
+	defer ctx.SetValue(executor.LoadDataVarKey, nil)
+	c.Assert(ld, NotNil)
+	tests := []testCase{
+		{nil, []byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, nil},
+		{nil, []byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, nil},
+	}
+	deleteSQL := "delete from load_data_test"
+	selectSQL := "select * from load_data_test;"
+	checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL)
+}
+
 func (s *testSuite) TestBatchInsertDelete(c *C) {
 	originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit)
 	defer func() {
diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go
index c61d97f6a4924..403ba83457172 100644
--- a/expression/builtin_cast.go
+++ b/expression/builtin_cast.go
@@ -464,7 +464,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b
 		res = 0
 	} else {
 		var uVal uint64
-		uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
+		sc := b.ctx.GetSessionVars().StmtCtx
+		uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
 		res = float64(uVal)
 	}
 	return res, false, errors.Trace(err)
@@ -491,7 +492,8 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe
 		res = &types.MyDecimal{}
 	} else {
 		var uVal uint64
-		uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
+		sc := b.ctx.GetSessionVars().StmtCtx
+		uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
 		if err != nil {
 			return res, false, errors.Trace(err)
 		}
@@ -520,7 +522,8 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul
 		res = strconv.FormatInt(val, 10)
 	} else {
 		var uVal uint64
-		uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
+		sc := b.ctx.GetSessionVars().StmtCtx
+		uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
 		if err != nil {
 			return res, false, errors.Trace(err)
 		}
@@ -747,7 +750,8 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool
 		res = 0
 	} else {
 		var uintVal uint64
-		uintVal, err = types.ConvertFloatToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble)
+		sc := b.ctx.GetSessionVars().StmtCtx
+		uintVal, err = types.ConvertFloatToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble)
 		res = int64(uintVal)
 	}
 	return res, isNull, errors.Trace(err)
diff --git a/expression/errors.go b/expression/errors.go
index 4b1a3163e1453..ce35d125bddbf 100644
--- a/expression/errors.go
+++ b/expression/errors.go
@@ -70,7 +70,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error {
 		return err
 	}
 	sc := ctx.GetSessionVars().StmtCtx
-	if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateOrDeleteStmt) {
+	if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) {
 		return err
 	}
 	sc.AppendWarning(err)
@@ -80,7 +80,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error {
 // handleDivisionByZeroError reports error or warning depend on the context.
 func handleDivisionByZeroError(ctx sessionctx.Context) error {
 	sc := ctx.GetSessionVars().StmtCtx
-	if sc.InInsertStmt || sc.InUpdateOrDeleteStmt {
+	if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt {
 		if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() {
 			return nil
 		}
diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go
index eecb56d1ac05c..6819b2eaeb8d2 100644
--- a/sessionctx/stmtctx/stmtctx.go
+++ b/sessionctx/stmtctx/stmtctx.go
@@ -47,8 +47,10 @@ type StatementContext struct {
 	// If IsDDLJobInQueue is true, it means the DDL job is in the queue of storage, and it can be handled by the DDL worker.
 	IsDDLJobInQueue        bool
 	InInsertStmt           bool
-	InUpdateOrDeleteStmt   bool
+	InUpdateStmt           bool
+	InDeleteStmt           bool
 	InSelectStmt           bool
+	InLoadDataStmt         bool
 	IgnoreTruncate         bool
 	IgnoreZeroInDate       bool
 	DupKeyAsWarning        bool
@@ -276,3 +278,21 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails {
 	sc.mu.Unlock()
 	return details
 }
+
+// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
+// This is the case for `insert`, `update`, `alter table` and `load data infile` statements, when not in strict SQL mode.
+// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
+func (sc *StatementContext) ShouldClipToZero() bool {
+	// TODO: Currently altering column of integer to unsigned integer is not supported.
+	// If it is supported one day, that case should be added here.
+	return sc.InInsertStmt || sc.InLoadDataStmt
+}
+
+// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows,
+// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types.
+func (sc *StatementContext) ShouldIgnoreOverflowError() bool {
+	if (sc.InInsertStmt && sc.TruncateAsWarning) || sc.InLoadDataStmt {
+		return true
+	}
+	return false
+}
diff --git a/types/convert.go b/types/convert.go
index 27df51b771997..87e3cd82e24a8 100644
--- a/types/convert.go
+++ b/types/convert.go
@@ -106,7 +106,11 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {
 }
 
 // ConvertIntToUint converts an int value to an uint value.
-func ConvertIntToUint(val int64, upperBound uint64, tp byte) (uint64, error) {
+func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) {
+	if sc.ShouldClipToZero() && val < 0 {
+		return 0, overflow(val, tp)
+	}
+
 	if uint64(val) > upperBound {
 		return upperBound, overflow(val, tp)
 	}
@@ -124,9 +128,12 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
 }
 
 // ConvertFloatToUint converts a float value to an uint value.
-func ConvertFloatToUint(fval float64, upperBound uint64, tp byte) (uint64, error) {
+func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
 	val := RoundFloat(fval)
 	if val < 0 {
+		if sc.ShouldClipToZero() {
+			return 0, overflow(val, tp)
+		}
 		return uint64(int64(val)), overflow(val, tp)
 	}
 
@@ -400,7 +407,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned
 			return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble)
 		}
 		bound := UnsignedUpperBound[mysql.TypeLonglong]
-		u, err := ConvertFloatToUint(f, bound, mysql.TypeDouble)
+		u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble)
 		return int64(u), errors.Trace(err)
 	case json.TypeCodeString:
 		return StrToInt(sc, hack.String(j.GetString()))
@@ -423,7 +430,7 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6
 	case json.TypeCodeInt64:
 		return float64(j.GetInt64()), nil
 	case json.TypeCodeUint64:
-		u, err := ConvertIntToUint(j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
+		u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong)
 		return float64(u), errors.Trace(err)
 	case json.TypeCodeFloat64:
 		return j.GetFloat64(), nil
diff --git a/types/datum.go b/types/datum.go
index 86d94792f2762..33558c416cfdb 100644
--- a/types/datum.go
+++ b/types/datum.go
@@ -865,21 +865,21 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
 	)
 	switch d.k {
 	case KindInt64:
-		val, err = ConvertIntToUint(d.GetInt64(), upperBound, tp)
+		val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp)
 	case KindUint64:
 		val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp)
 	case KindFloat32, KindFloat64:
-		val, err = ConvertFloatToUint(d.GetFloat64(), upperBound, tp)
+		val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp)
 	case KindString, KindBytes:
-		val, err = StrToUint(sc, d.GetString())
-		if err != nil {
-			return ret, errors.Trace(err)
+		uval, err1 := StrToUint(sc, d.GetString())
+		if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() {
+			return ret, errors.Trace(err1)
 		}
-		val, err = ConvertUintToUint(val, upperBound, tp)
+		val, err = ConvertUintToUint(uval, upperBound, tp)
 		if err != nil {
 			return ret, errors.Trace(err)
 		}
-		ret.SetUint64(val)
+		err = err1
 	case KindMysqlTime:
 		dec := d.GetMysqlTime().ToNumber()
 		err = dec.Round(dec, 0, ModeHalfEven)
@@ -887,7 +887,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
 		if err == nil {
 			err = err1
 		}
-		val, err1 = ConvertIntToUint(ival, upperBound, tp)
+		val, err1 = ConvertIntToUint(sc, ival, upperBound, tp)
 		if err == nil {
 			err = err1
 		}
@@ -896,18 +896,18 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
 		err = dec.Round(dec, 0, ModeHalfEven)
 		ival, err1 := dec.ToInt()
 		if err1 == nil {
-			val, err = ConvertIntToUint(ival, upperBound, tp)
+			val, err = ConvertIntToUint(sc, ival, upperBound, tp)
 		}
 	case KindMysqlDecimal:
 		fval, err1 := d.GetMysqlDecimal().ToFloat64()
-		val, err = ConvertFloatToUint(fval, upperBound, tp)
+		val, err = ConvertFloatToUint(sc, fval, upperBound, tp)
 		if err == nil {
 			err = err1
 		}
 	case KindMysqlEnum:
-		val, err = ConvertFloatToUint(d.GetMysqlEnum().ToNumber(), upperBound, tp)
+		val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
 	case KindMysqlSet:
-		val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp)
+		val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp)
 	case KindBinaryLiteral, KindMysqlBit:
 		val, err = d.GetBinaryLiteral().ToInt(sc)
 	case KindMysqlJSON:
@@ -1137,7 +1137,7 @@ func ProduceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *stmtctx.Statem
 				return nil, errors.Trace(err)
 			}
 			if !dec.IsZero() && frac > decimal && dec.Compare(&old) != 0 {
-				if sc.InInsertStmt || sc.InUpdateOrDeleteStmt {
+				if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt {
 					// fix https://github.com/pingcap/tidb/issues/3895
 					// fix https://github.com/pingcap/tidb/issues/5532
 					sc.AppendWarning(ErrTruncated)