diff --git a/ddl/db_test.go b/ddl/db_test.go index 14b0dbfaadbec..0337445c4aa7b 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -2553,6 +2553,22 @@ func (s *testDBSuite) TestAddIndexForGeneratedColumn(c *C) { s.mustExec(c, "delete from t where y = 2155") s.mustExec(c, "alter table t add index idx_y(y1)") s.mustExec(c, "alter table t drop index idx_y") + + // Fix issue 9311. + s.tk.MustExec("create table gcai_table (id int primary key);") + s.tk.MustExec("insert into gcai_table values(1);") + s.tk.MustExec("ALTER TABLE gcai_table ADD COLUMN d date DEFAULT '9999-12-31';") + s.tk.MustExec("ALTER TABLE gcai_table ADD COLUMN d1 date as (DATE_SUB(d, INTERVAL 31 DAY));") + s.tk.MustExec("ALTER TABLE gcai_table ADD INDEX idx(d1);") + s.tk.MustQuery("select * from gcai_table").Check(testkit.Rows("1 9999-12-31 9999-11-30")) + s.tk.MustQuery("select d1 from gcai_table use index(idx)").Check(testkit.Rows("9999-11-30")) + s.tk.MustExec("admin check table gcai_table") + // The column is PKIsHandle in generated column expression. + s.tk.MustExec("ALTER TABLE gcai_table ADD COLUMN id1 int as (id+5);") + s.tk.MustExec("ALTER TABLE gcai_table ADD INDEX idx1(id1);") + s.tk.MustQuery("select * from gcai_table").Check(testkit.Rows("1 9999-12-31 9999-11-30 6")) + s.tk.MustQuery("select id1 from gcai_table use index(idx1)").Check(testkit.Rows("6")) + s.tk.MustExec("admin check table gcai_table") } func (s *testDBSuite) TestIssue9100(c *C) { diff --git a/ddl/index.go b/ddl/index.go index 95571d66dd18f..25e01d9dc0a40 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -517,7 +517,7 @@ func mergeAddIndexCtxToResult(taskCtx *addIndexTaskContext, result *addIndexResu func newAddIndexWorker(sessCtx sessionctx.Context, worker *worker, id int, t table.PhysicalTable, indexInfo *model.IndexInfo, decodeColMap map[int64]decoder.Column) *addIndexWorker { index := tables.NewIndex(t.GetPhysicalID(), t.Meta(), indexInfo) - rowDecoder := decoder.NewRowDecoder(t.Cols(), decodeColMap) + rowDecoder := decoder.NewRowDecoder(t, decodeColMap) return &addIndexWorker{ id: id, ddlWorker: worker, @@ -547,7 +547,7 @@ func (w *addIndexWorker) getIndexRecord(handle int64, recordKey []byte, rawRecor cols := t.Cols() idxInfo := w.index.Meta() sysZone := timeutil.SystemLocation() - _, err := w.rowDecoder.DecodeAndEvalRowWithMap(w.sessCtx, rawRecord, time.UTC, sysZone, w.rowMap) + _, err := w.rowDecoder.DecodeAndEvalRowWithMap(w.sessCtx, handle, rawRecord, time.UTC, sysZone, w.rowMap) if err != nil { return nil, errors.Trace(errCantDecodeIndex.GenWithStackByArgs(err)) } @@ -899,13 +899,13 @@ func makeupDecodeColMap(sessCtx sessionctx.Context, t table.Table, indexInfo *mo for _, v := range indexInfo.Columns { col := cols[v.Offset] tpExpr := decoder.Column{ - Info: col.ToInfo(), + Col: col, } if col.IsGenerated() && !col.GeneratedStored { for _, c := range cols { if _, ok := col.Dependences[c.Name.L]; ok { decodeColMap[c.ID] = decoder.Column{ - Info: c.ToInfo(), + Col: c, } } } diff --git a/util/admin/admin.go b/util/admin/admin.go index e591453d729bc..15a8cedad2a9b 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -599,13 +599,13 @@ func makeRowDecoder(t table.Table, decodeCol []*table.Column, genExpr map[model. for _, v := range decodeCol { col := cols[v.Offset] tpExpr := decoder.Column{ - Info: col.ToInfo(), + Col: col, } if col.IsGenerated() && !col.GeneratedStored { for _, c := range cols { if _, ok := col.Dependences[c.Name.L]; ok { decodeColsMap[c.ID] = decoder.Column{ - Info: c.ToInfo(), + Col: c, } } } @@ -613,7 +613,7 @@ func makeRowDecoder(t table.Table, decodeCol []*table.Column, genExpr map[model. } decodeColsMap[col.ID] = tpExpr } - return decoder.NewRowDecoder(cols, decodeColsMap) + return decoder.NewRowDecoder(t, decodeColsMap) } // genExprs use to calculate generated column value. @@ -641,7 +641,7 @@ func rowWithCols(sessCtx sessionctx.Context, txn kv.Retriever, t table.Table, h } } - rowMap, err := rowDecoder.DecodeAndEvalRowWithMap(sessCtx, value, sessCtx.GetSessionVars().Location(), time.UTC, nil) + rowMap, err := rowDecoder.DecodeAndEvalRowWithMap(sessCtx, h, value, sessCtx.GetSessionVars().Location(), time.UTC, nil) if err != nil { return nil, errors.Trace(err) } @@ -702,7 +702,7 @@ func iterRecords(sessCtx sessionctx.Context, retriever kv.Retriever, t table.Tab return errors.Trace(err) } - rowMap, err := rowDecoder.DecodeAndEvalRowWithMap(sessCtx, it.Value(), sessCtx.GetSessionVars().Location(), time.UTC, nil) + rowMap, err := rowDecoder.DecodeAndEvalRowWithMap(sessCtx, handle, it.Value(), sessCtx.GetSessionVars().Location(), time.UTC, nil) if err != nil { return errors.Trace(err) } diff --git a/util/rowDecoder/decoder.go b/util/rowDecoder/decoder.go index 4ea12d1394299..9768eca91c8a8 100644 --- a/util/rowDecoder/decoder.go +++ b/util/rowDecoder/decoder.go @@ -17,11 +17,11 @@ import ( "time" "github.com/pingcap/errors" - "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -29,24 +29,26 @@ import ( // Column contains the info and generated expr of column. type Column struct { - Info *model.ColumnInfo + Col *table.Column GenExpr expression.Expression } // RowDecoder decodes a byte slice into datums and eval the generated column value. type RowDecoder struct { + tbl table.Table mutRow chunk.MutRow columns map[int64]Column colTypes map[int64]*types.FieldType haveGenColumn bool + defaultVals []types.Datum } // NewRowDecoder returns a new RowDecoder. -func NewRowDecoder(cols []*table.Column, decodeColMap map[int64]Column) *RowDecoder { +func NewRowDecoder(tbl table.Table, decodeColMap map[int64]Column) *RowDecoder { colFieldMap := make(map[int64]*types.FieldType, len(decodeColMap)) haveGenCol := false for id, col := range decodeColMap { - colFieldMap[id] = &col.Info.FieldType + colFieldMap[id] = &col.Col.ColumnInfo.FieldType if col.GenExpr != nil { haveGenCol = true } @@ -57,20 +59,23 @@ func NewRowDecoder(cols []*table.Column, decodeColMap map[int64]Column) *RowDeco } } + cols := tbl.Cols() tps := make([]*types.FieldType, len(cols)) for _, col := range cols { tps[col.Offset] = &col.FieldType } return &RowDecoder{ + tbl: tbl, mutRow: chunk.MutRowFromTypes(tps), columns: decodeColMap, colTypes: colFieldMap, haveGenColumn: haveGenCol, + defaultVals: make([]types.Datum, len(cols)), } } // DecodeAndEvalRowWithMap decodes a byte slice into datums and evaluates the generated column value. -func (rd *RowDecoder) DecodeAndEvalRowWithMap(ctx sessionctx.Context, b []byte, decodeLoc, sysLoc *time.Location, row map[int64]types.Datum) (map[int64]types.Datum, error) { +func (rd *RowDecoder) DecodeAndEvalRowWithMap(ctx sessionctx.Context, handle int64, b []byte, decodeLoc, sysLoc *time.Location, row map[int64]types.Datum) (map[int64]types.Datum, error) { row, err := tablecodec.DecodeRowWithMap(b, rd.colTypes, decodeLoc, row) if err != nil { return nil, errors.Trace(err) @@ -79,8 +84,28 @@ func (rd *RowDecoder) DecodeAndEvalRowWithMap(ctx sessionctx.Context, b []byte, return row, nil } - for id, v := range row { - rd.mutRow.SetValue(rd.columns[id].Info.Offset, v.GetValue()) + for _, dCol := range rd.columns { + colInfo := dCol.Col.ColumnInfo + val, ok := row[colInfo.ID] + if ok || dCol.GenExpr != nil { + rd.mutRow.SetValue(colInfo.Offset, val.GetValue()) + continue + } + + // Get the default value of the column in the generated column expression. + if dCol.Col.IsPKHandleColumn(rd.tbl.Meta()) { + if mysql.HasUnsignedFlag(colInfo.Flag) { + val.SetUint64(uint64(handle)) + } else { + val.SetInt64(handle) + } + } else { + val, err = tables.GetColDefaultValue(ctx, dCol.Col, rd.defaultVals) + if err != nil { + return nil, errors.Trace(err) + } + } + rd.mutRow.SetValue(colInfo.Offset, val.GetValue()) } for id, col := range rd.columns { if col.GenExpr == nil { @@ -91,7 +116,7 @@ func (rd *RowDecoder) DecodeAndEvalRowWithMap(ctx sessionctx.Context, b []byte, if err != nil { return nil, errors.Trace(err) } - val, err = table.CastValue(ctx, val, col.Info) + val, err = table.CastValue(ctx, val, col.Col.ColumnInfo) if err != nil { return nil, errors.Trace(err) }