diff --git a/executor/builder.go b/executor/builder.go index 30f49613ba20c..e10a2b6b11970 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1205,15 +1205,6 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo } } - // consider collations - leftTypes := make([]*types.FieldType, 0, len(retTypes(leftExec))) - for _, tp := range retTypes(leftExec) { - leftTypes = append(leftTypes, tp.Clone()) - } - rightTypes := make([]*types.FieldType, 0, len(retTypes(rightExec))) - for _, tp := range retTypes(rightExec) { - rightTypes = append(rightTypes, tp.Clone()) - } leftIsBuildSide := true e.isNullEQ = v.IsNullEQ @@ -1256,24 +1247,32 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo } executorCountHashJoinExec.Inc() + // We should use JoinKey to construct the type information used by hashing, instead of using the child's schema directly. + // When a hybrid type column is hashed multiple times, we need to distinguish what field types are used. + // For example, the condition `enum = int and enum = string`, we should use ETInt to hash the first column, + // and use ETString to hash the second column, although they may be the same column.
+ leftExecTypes, rightExecTypes := retTypes(leftExec), retTypes(rightExec) + leftTypes, rightTypes := make([]*types.FieldType, 0, len(v.LeftJoinKeys)), make([]*types.FieldType, 0, len(v.RightJoinKeys)) + for i, col := range v.LeftJoinKeys { + leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) + leftTypes[i].Flag = col.RetType.Flag + } + for i, col := range v.RightJoinKeys { + rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) + rightTypes[i].Flag = col.RetType.Flag + } + + // consider collations for i := range v.EqualConditions { chs, coll := v.EqualConditions[i].CharsetAndCollation(e.ctx) - bt := leftTypes[v.LeftJoinKeys[i].Index] - bt.Charset, bt.Collate = chs, coll - pt := rightTypes[v.RightJoinKeys[i].Index] - pt.Charset, pt.Collate = chs, coll + leftTypes[i].Charset, leftTypes[i].Collate = chs, coll + rightTypes[i].Charset, rightTypes[i].Collate = chs, coll } if leftIsBuildSide { e.buildTypes, e.probeTypes = leftTypes, rightTypes } else { e.buildTypes, e.probeTypes = rightTypes, leftTypes } - for _, key := range e.buildKeys { - e.buildTypes[key.Index].Flag = key.RetType.Flag - } - for _, key := range e.probeKeys { - e.probeTypes[key.Index].Flag = key.RetType.Flag - } return e } @@ -2701,6 +2700,21 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) outerTypes[col.Index].Flag = col.RetType.Flag } + // We should use JoinKey to construct the type information used by hashing, instead of using the child's schema directly. + // When a hybrid type column is hashed multiple times, we need to distinguish what field types are used. + // For example, the condition `enum = int and enum = string`, we should use ETInt to hash the first column, + // and use ETString to hash the second column, although they may be the same column.
+ innerHashTypes := make([]*types.FieldType, len(v.InnerHashKeys)) + outerHashTypes := make([]*types.FieldType, len(v.OuterHashKeys)) + for i, col := range v.InnerHashKeys { + innerHashTypes[i] = innerTypes[col.Index].Clone() + innerHashTypes[i].Flag = col.RetType.Flag + } + for i, col := range v.OuterHashKeys { + outerHashTypes[i] = outerTypes[col.Index].Clone() + outerHashTypes[i].Flag = col.RetType.Flag + } + var ( outerFilter []expression.Expression leftTypes, rightTypes []*types.FieldType @@ -2735,12 +2749,14 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) e := &IndexLookUpJoin{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec), outerCtx: outerCtx{ - rowTypes: outerTypes, - filter: outerFilter, + rowTypes: outerTypes, + hashTypes: outerHashTypes, + filter: outerFilter, }, innerCtx: innerCtx{ readerBuilder: &dataReaderBuilder{Plan: innerPlan, executorBuilder: b}, rowTypes: innerTypes, + hashTypes: innerHashTypes, colLens: v.IdxColLens, hasPrefixCol: hasPrefixCol, }, diff --git a/executor/hash_table.go b/executor/hash_table.go index 0f5dc311ccd60..b22f98bbef501 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -34,6 +34,7 @@ import ( // hashContext keeps the needed hash context of a db table in hash join. 
type hashContext struct { + // allTypes one-to-one correspondence with keyColIdx allTypes []*types.FieldType keyColIdx []int buf []byte @@ -84,9 +85,9 @@ type hashRowContainer struct { rowContainer *chunk.RowContainer } -func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContext) *hashRowContainer { +func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContext, allTypes []*types.FieldType) *hashRowContainer { maxChunkSize := sCtx.GetSessionVars().MaxChunkSize - rc := chunk.NewRowContainer(hCtx.allTypes, maxChunkSize) + rc := chunk.NewRowContainer(allTypes, maxChunkSize) c := &hashRowContainer{ sc: sCtx.GetSessionVars().StmtCtx, hCtx: hCtx, @@ -171,7 +172,7 @@ func (c *hashRowContainer) PutChunkSelected(chk *chunk.Chunk, selected, ignoreNu hCtx := c.hCtx for keyIdx, colIdx := range c.hCtx.keyColIdx { ignoreNull := len(ignoreNulls) > keyIdx && ignoreNulls[keyIdx] - err := codec.HashChunkSelected(c.sc, hCtx.hashVals, chk, hCtx.allTypes[colIdx], colIdx, hCtx.buf, hCtx.hasNull, selected, ignoreNull) + err := codec.HashChunkSelected(c.sc, hCtx.hashVals, chk, hCtx.allTypes[keyIdx], colIdx, hCtx.buf, hCtx.hasNull, selected, ignoreNull) if err != nil { return errors.Trace(err) } diff --git a/executor/hash_table_serial_test.go b/executor/hash_table_serial_test.go index 3def2a0871adc..50ef8adb72f5e 100644 --- a/executor/hash_table_serial_test.go +++ b/executor/hash_table_serial_test.go @@ -126,7 +126,7 @@ func testHashRowContainer(t *testing.T, hashFunc func() hash.Hash64, spill bool) for i := 0; i < numRows; i++ { hCtx.hashVals = append(hCtx.hashVals, hashFunc()) } - rowContainer := newHashRowContainer(sctx, 0, hCtx) + rowContainer := newHashRowContainer(sctx, 0, hCtx, hCtx.allTypes) copiedRC = rowContainer.ShallowCopy() tracker := rowContainer.GetMemTracker() tracker.SetLabel(memory.LabelForBuildSideResult) diff --git a/executor/index_lookup_hash_join.go b/executor/index_lookup_hash_join.go index e959957dc1cba..b77e446c62104 
100644 --- a/executor/index_lookup_hash_join.go +++ b/executor/index_lookup_hash_join.go @@ -554,7 +554,7 @@ func (iw *indexHashJoinInnerWorker) buildHashTableForOuterResult(ctx context.Con } } h.Reset() - err := codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx, h, row, iw.outerCtx.rowTypes, hashColIdx, buf) + err := codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx, h, row, iw.outerCtx.hashTypes, hashColIdx, buf) failpoint.Inject("testIndexHashJoinBuildErr", func() { err = errors.New("mockIndexHashJoinBuildErr") }) @@ -645,7 +645,7 @@ func (iw *indexHashJoinInnerWorker) doJoinUnordered(ctx context.Context, task *i func (iw *indexHashJoinInnerWorker) getMatchedOuterRows(innerRow chunk.Row, task *indexHashJoinTask, h hash.Hash64, buf []byte) (matchedRows []chunk.Row, matchedRowPtr []chunk.RowPtr, err error) { h.Reset() - err = codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx, h, innerRow, iw.rowTypes, iw.hashCols, buf) + err = codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx, h, innerRow, iw.hashTypes, iw.hashCols, buf) if err != nil { return nil, nil, err } @@ -659,7 +659,7 @@ func (iw *indexHashJoinInnerWorker) getMatchedOuterRows(innerRow chunk.Row, task matchedRowPtr = make([]chunk.RowPtr, 0, len(iw.matchedOuterPtrs)) for _, ptr := range iw.matchedOuterPtrs { outerRow := task.outerResult.GetRow(ptr) - ok, err := codec.EqualChunkRow(iw.ctx.GetSessionVars().StmtCtx, innerRow, iw.rowTypes, iw.keyCols, outerRow, iw.outerCtx.rowTypes, iw.outerCtx.hashCols) + ok, err := codec.EqualChunkRow(iw.ctx.GetSessionVars().StmtCtx, innerRow, iw.hashTypes, iw.hashCols, outerRow, iw.outerCtx.hashTypes, iw.outerCtx.hashCols) if err != nil { return nil, nil, err } diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index bb1caa7de3881..5f4945f3fd55c 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -86,10 +86,11 @@ type IndexLookUpJoin struct { } type outerCtx struct { - rowTypes []*types.FieldType - keyCols []int - 
hashCols []int - filter expression.CNFExprs + rowTypes []*types.FieldType + keyCols []int + hashTypes []*types.FieldType + hashCols []int + filter expression.CNFExprs } type innerCtx struct { @@ -97,6 +98,7 @@ type innerCtx struct { rowTypes []*types.FieldType keyCols []int keyColIDs []int64 // the original ID in its table, used by dynamic partition pruning + hashTypes []*types.FieldType hashCols []int colLens []int hasPrefixCol bool diff --git a/executor/join.go b/executor/join.go index 716ce2a229d3f..2b97c0b1f93cf 100644 --- a/executor/join.go +++ b/executor/join.go @@ -566,7 +566,7 @@ func (e *HashJoinExec) join2Chunk(workerID uint, probeSideChk *chunk.Chunk, hCtx hCtx.initHash(probeSideChk.NumRows()) for keyIdx, i := range hCtx.keyColIdx { ignoreNull := len(e.isNullEQ) > keyIdx && e.isNullEQ[keyIdx] - err = codec.HashChunkSelected(rowContainer.sc, hCtx.hashVals, probeSideChk, hCtx.allTypes[i], i, hCtx.buf, hCtx.hasNull, selected, ignoreNull) + err = codec.HashChunkSelected(rowContainer.sc, hCtx.hashVals, probeSideChk, hCtx.allTypes[keyIdx], i, hCtx.buf, hCtx.hasNull, selected, ignoreNull) if err != nil { joinResult.err = err return false, joinResult @@ -607,8 +607,8 @@ func (e *HashJoinExec) join2Chunk(workerID uint, probeSideChk *chunk.Chunk, hCtx // join2ChunkForOuterHashJoin joins chunks when using the outer to build a hash table (refer to outer hash join) func (e *HashJoinExec) join2ChunkForOuterHashJoin(workerID uint, probeSideChk *chunk.Chunk, hCtx *hashContext, rowContainer *hashRowContainer, joinResult *hashjoinWorkerResult) (ok bool, _ *hashjoinWorkerResult) { hCtx.initHash(probeSideChk.NumRows()) - for _, i := range hCtx.keyColIdx { - err := codec.HashChunkColumns(rowContainer.sc, hCtx.hashVals, probeSideChk, hCtx.allTypes[i], i, hCtx.buf, hCtx.hasNull) + for keyIdx, i := range hCtx.keyColIdx { + err := codec.HashChunkColumns(rowContainer.sc, hCtx.hashVals, probeSideChk, hCtx.allTypes[keyIdx], i, hCtx.buf, hCtx.hasNull) if err != nil { joinResult.err 
= err return false, joinResult @@ -656,7 +656,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { allTypes: e.buildTypes, keyColIdx: buildKeyColIdx, } - e.rowContainer = newHashRowContainer(e.ctx, int(e.buildSideEstCount), hCtx) + e.rowContainer = newHashRowContainer(e.ctx, int(e.buildSideEstCount), hCtx, retTypes(e.buildSideExec)) // we shallow copies rowContainer for each probe worker to avoid lock contention e.rowContainerForProbe = make([]*hashRowContainer, e.concurrency) for i := uint(0); i < e.concurrency; i++ { diff --git a/expression/builtin_control.go b/expression/builtin_control.go index e9b39bf36ab5c..6ea2c119176c7 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -237,6 +237,14 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre return nil, err } bf.tp = fieldTp + if fieldTp.Tp == mysql.TypeEnum || fieldTp.Tp == mysql.TypeSet { + switch tp { + case types.ETInt: + fieldTp.Tp = mysql.TypeLonglong + case types.ETString: + fieldTp.Tp = mysql.TypeVarchar + } + } switch tp { case types.ETInt: diff --git a/expression/constant_fold.go b/expression/constant_fold.go index f08a5c45abf60..d7cbaf5d8edd6 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -143,7 +143,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { foldedExpr.GetType().Decimal = expr.GetType().Decimal return foldedExpr, isDeferredConst } - return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + return foldedExpr, isDeferredConst } return expr, isDeferredConst } diff --git a/expression/integration_test.go b/expression/integration_test.go index d071bcc4ec7a1..14894600079e2 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -9936,6 +9936,19 @@ func (s *testIntegrationSuite) TestControlFunctionWithEnumOrSet(c *C) { tk.MustExec("insert into t values(1,1,1),(2,1,1),(1,1,1),(2,1,1);") 
tk.MustQuery("select if(A, null,b)=1 from t;").Check(testkit.Rows("", "", "", "")) tk.MustQuery("select if(A, null,b)='a' from t;").Check(testkit.Rows("", "", "", "")) + + // issue 29357 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(`a` enum('y','b','Abc','null','1','2','0')) CHARSET=binary;") + tk.MustExec("insert into t values(\"1\");") + tk.MustQuery("SELECT count(*) from t where (null like 'a') = (case when cast('2015' as real) <=> round(\"1200\",\"1\") then a end);\n").Check(testkit.Rows("0")) + tk.MustQuery("SELECT (null like 'a') = (case when cast('2015' as real) <=> round(\"1200\",\"1\") then a end) from t;\n").Check(testkit.Rows("")) + tk.MustQuery("SELECT 5 = (case when 0 <=> 0 then a end) from t;").Check(testkit.Rows("1")) + tk.MustQuery("SELECT '1' = (case when 0 <=> 0 then a end) from t;").Check(testkit.Rows("1")) + tk.MustQuery("SELECT 5 = (case when 0 <=> 1 then a end) from t;").Check(testkit.Rows("")) + tk.MustQuery("SELECT '1' = (case when 0 <=> 1 then a end) from t;").Check(testkit.Rows("")) + tk.MustQuery("SELECT 5 = (case when 0 <=> 1 then a else a end) from t;").Check(testkit.Rows("1")) + tk.MustQuery("SELECT '1' = (case when 0 <=> 1 then a else a end) from t;").Check(testkit.Rows("1")) } func (s *testIntegrationSuite) TestComplexShowVariables(c *C) { @@ -10432,6 +10445,20 @@ func (s *testIntegrationSuite) TestIssue28643(c *C) { tk.MustQuery("select hour(a) from t;").Check(testkit.Rows("838", "838")) } +func (s *testIntegrationSuite) TestIssue27831(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a enum(\"a\", \"b\"), b enum(\"a\", \"b\"), c bool)") + tk.MustExec("insert into t values(\"a\", \"a\", 1);") + tk.MustQuery("select * from t t1 right join t t2 on t1.a=t2.b and t1.a= t2.c;").Check(testkit.Rows("a a 1 a a 1")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a enum(\"a\", \"b\"), b enum(\"a\", 
\"b\"), c bool, d int, index idx(d))") + tk.MustExec("insert into t values(\"a\", \"a\", 1, 1);") + tk.MustQuery("select /*+ inl_hash_join(t1) */ * from t t1 right join t t2 on t1.a=t2.b and t1.a= t2.c and t1.d=t2.d;").Check(testkit.Rows("a a 1 1 a a 1 1")) +} + func (s *testIntegrationSuite) TestIssue29434(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/util/codec/codec.go b/util/codec/codec.go index f08c113a61c81..f6daf7e54261d 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -665,8 +665,8 @@ func HashChunkSelected(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk // If two rows are logically equal, it will generate the same bytes. func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int, buf []byte) (err error) { var b []byte - for _, idx := range colIdx { - buf[0], b, err = encodeHashChunkRowIdx(sc, row, allTypes[idx], idx) + for i, idx := range colIdx { + buf[0], b, err = encodeHashChunkRowIdx(sc, row, allTypes[i], idx) if err != nil { return errors.Trace(err) } @@ -688,13 +688,16 @@ func EqualChunkRow(sc *stmtctx.StatementContext, row1 chunk.Row, allTypes1 []*types.FieldType, colIdx1 []int, row2 chunk.Row, allTypes2 []*types.FieldType, colIdx2 []int, ) (bool, error) { + if len(colIdx1) != len(colIdx2) { + return false, errors.Errorf("Internal error: Hash columns count mismatch, col1: %d, col2: %d", len(colIdx1), len(colIdx2)) + } for i := range colIdx1 { idx1, idx2 := colIdx1[i], colIdx2[i] - flag1, b1, err := encodeHashChunkRowIdx(sc, row1, allTypes1[idx1], idx1) + flag1, b1, err := encodeHashChunkRowIdx(sc, row1, allTypes1[i], idx1) if err != nil { return false, errors.Trace(err) } - flag2, b2, err := encodeHashChunkRowIdx(sc, row2, allTypes2[idx2], idx2) + flag2, b2, err := encodeHashChunkRowIdx(sc, row2, allTypes2[i], idx2) if err != nil { return false, errors.Trace(err) } diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go 
index e1719bb7345e8..ac41a148160cd 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -1280,9 +1280,9 @@ func TestHashChunkColumns(t *testing.T) { for i := 0; i < 12; i++ { require.True(t, chk.GetRow(0).IsNull(i)) err1 := HashChunkSelected(sc, vecHash, chk, tps[i], i, buf, hasNull, sel, false) - err2 := HashChunkRow(sc, rowHash[0], chk.GetRow(0), tps, colIdx[i:i+1], buf) - err3 := HashChunkRow(sc, rowHash[1], chk.GetRow(1), tps, colIdx[i:i+1], buf) - err4 := HashChunkRow(sc, rowHash[2], chk.GetRow(2), tps, colIdx[i:i+1], buf) + err2 := HashChunkRow(sc, rowHash[0], chk.GetRow(0), tps[i:i+1], colIdx[i:i+1], buf) + err3 := HashChunkRow(sc, rowHash[1], chk.GetRow(1), tps[i:i+1], colIdx[i:i+1], buf) + err4 := HashChunkRow(sc, rowHash[2], chk.GetRow(2), tps[i:i+1], colIdx[i:i+1], buf) require.NoError(t, err1) require.NoError(t, err2) require.NoError(t, err3) @@ -1305,9 +1305,9 @@ func TestHashChunkColumns(t *testing.T) { require.False(t, chk.GetRow(0).IsNull(i)) err1 := HashChunkSelected(sc, vecHash, chk, tps[i], i, buf, hasNull, sel, false) - err2 := HashChunkRow(sc, rowHash[0], chk.GetRow(0), tps, colIdx[i:i+1], buf) - err3 := HashChunkRow(sc, rowHash[1], chk.GetRow(1), tps, colIdx[i:i+1], buf) - err4 := HashChunkRow(sc, rowHash[2], chk.GetRow(2), tps, colIdx[i:i+1], buf) + err2 := HashChunkRow(sc, rowHash[0], chk.GetRow(0), tps[i:i+1], colIdx[i:i+1], buf) + err3 := HashChunkRow(sc, rowHash[1], chk.GetRow(1), tps[i:i+1], colIdx[i:i+1], buf) + err4 := HashChunkRow(sc, rowHash[2], chk.GetRow(2), tps[i:i+1], colIdx[i:i+1], buf) require.NoError(t, err1) require.NoError(t, err2)