diff --git a/enginetest/join_planning_tests.go b/enginetest/join_planning_tests.go index 5cbcbf959c..5d0f15e0a3 100644 --- a/enginetest/join_planning_tests.go +++ b/enginetest/join_planning_tests.go @@ -1061,12 +1061,12 @@ join uv d on d.u = c.x`, }, }, { - name: "primary key range join", + name: "indexed range join", setup: []string{ - "create table vals (val int primary key);", - "create table ranges (min int primary key, max int, unique key(min,max));", - "insert into vals values (0), (1), (2), (3), (4), (5), (6);", - "insert into ranges values (0,2), (1,3), (2,4), (3,5), (4,6);", + "create table vals (val int unique key);", + "create table ranges (min int unique key, max int, unique key(min,max));", + "insert into vals values (null), (0), (1), (2), (3), (4), (5), (6);", + "insert into ranges values (null,1), (0,2), (1,3), (2,4), (3,5), (4,6);", }, tests: []JoinPlanTest{ { @@ -1247,8 +1247,9 @@ join uv d on d.u = c.x`, }, { q: "select * from vals where exists (select * from vals join ranges on val between min and max where min >= 2 and max <= 5)", - types: []plan.JoinType{plan.JoinTypeCross, plan.JoinTypeInner}, + types: []plan.JoinType{plan.JoinTypeCrossHash, plan.JoinTypeInner}, exp: []sql.Row{ + {nil}, {0}, {1}, {2}, @@ -1297,18 +1298,20 @@ join uv d on d.u = c.x`, {6}, }, }, - { - q: "select * from vals where exists (select * from ranges where val between min and max limit 1 offset 1);", - types: []plan.JoinType{plan.JoinTypeSemi}, - exp: []sql.Row{ - {1}, - {2}, - {3}, - {4}, - {5}, - {6}, - }, - }, + /* + Disabled because of https://github.com/dolthub/go-mysql-server/issues/2277 + { + q: "select * from vals where exists (select * from ranges where val between min and max limit 1 offset 1);", + types: []plan.JoinType{plan.JoinTypeSemi}, + exp: []sql.Row{ + {1}, + {2}, + {3}, + {4}, + {5}, + }, + }, + */ { q: "select * from vals where exists (select * from ranges where val between min and max having val > 1);", types: []plan.JoinType{}, @@ -1327,8 +1330,8 @@ join uv d on d.u = c.x`, setup: []string{ "create table vals (val int)", "create table ranges (min int, max int)", - "insert into vals values (0), (1), (2), (3), (4), (5), (6)", - "insert into ranges values (0,2), (1,3), (2,4), (3,5), (4,6)", + "insert into vals values (null), (0), (1), (2), (3), (4), (5), (6)", + "insert into ranges values (null,1), (0,2), (1,3), (2,4), (3,5), (4,6)", }, tests: []JoinPlanTest{ { @@ -1431,6 +1434,7 @@ join uv d on d.u = c.x`, q: "select * from vals left join ranges on val > min and val < max", types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap}, exp: []sql.Row{ + {nil, nil, nil}, {0, nil, nil}, {1, 0, 2}, {2, 1, 3}, @@ -1454,6 +1458,7 @@ join uv d on d.u = c.x`, q: "select * from vals left join ranges r1 on val > r1.min and val < r1.max left join ranges r2 on r1.min > r2.min and r1.min < r2.max", types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap, plan.JoinTypeLeftOuterRangeHeap}, exp: []sql.Row{ + {nil, nil, nil, nil, nil}, {0, nil, nil, nil, nil}, {1, 0, 2, nil, nil}, {2, 1, 3, 0, 2}, @@ -1495,6 +1500,7 @@ join uv d on d.u = c.x`, q: "select * from vals left join (select * from ranges where 0) as newRanges on val > min and val < max;", types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap}, exp: []sql.Row{ + {nil, nil, nil}, {0, nil, nil}, {1, nil, nil}, {2, nil, nil}, diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 95cc2a18b3..7e870ec246 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -8892,6 +8892,18 @@ from typestable`, {"abc"}, }, }, + { + Query: "select count(distinct cast(i as decimal)) from mytable;", + Expected: []sql.Row{ + {3}, + }, + }, + { + Query: "select count(distinct null);", + Expected: []sql.Row{ + {0}, + }, + }, } var KeylessQueries = []QueryTest{ diff --git a/enginetest/queries/query_plans.go b/enginetest/queries/query_plans.go index bab0bb0fff..77ef4494a7 100644 --- a/enginetest/queries/query_plans.go +++ b/enginetest/queries/query_plans.go @@ -25877,7 +25877,7 @@ order by x, y; " │ └─ Table\n" + " │ ├─ name: xy\n" + " │ └─ columns: [x y]\n" + - " └─ Sort(bigtable.n:1 ASC nullsLast)\n" + + " └─ Sort(bigtable.n:1 ASC nullsFirst)\n" + " └─ ProcessTable\n" + " └─ Table\n" + " ├─ name: bigtable\n" + diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 8930a68104..d299bf9801 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -5167,6 +5167,27 @@ CREATE TABLE tab3 ( }, }, }, + { + Name: "count distinct decimals", + SetUpScript: []string{ + "create table t (i int, j int)", + "insert into t values (1, 11), (11, 1)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select count(distinct i, j) from t;", + Expected: []sql.Row{ + {2}, + }, + }, + { + Query: "select count(distinct cast(i as decimal), cast(j as decimal)) from t;", + Expected: []sql.Row{ + {2}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/expression/function/aggregation/count_test.go b/sql/expression/function/aggregation/count_test.go index 30daf7bb19..7661db68cf 100644 --- a/sql/expression/function/aggregation/count_test.go +++ b/sql/expression/function/aggregation/count_test.go @@ -100,7 +100,7 @@ func TestCountDistinctEvalStar(t *testing.T) { require.NoError(b.Update(ctx, sql.NewRow(1))) require.NoError(b.Update(ctx, sql.NewRow(nil))) require.NoError(b.Update(ctx, sql.NewRow(1, 2, 3))) - require.Equal(int64(5), evalBuffer(t, b)) + require.Equal(int64(4), evalBuffer(t, b)) } func TestCountDistinctEvalString(t *testing.T) { diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index df0c4788f3..07729f68b4 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -4,7 +4,7 @@ import ( "fmt" "reflect" - "github.com/mitchellh/hashstructure" + "github.com/cespare/xxhash/v2" "github.com/shopspring/decimal" "github.com/dolthub/go-mysql-server/sql" @@ -402,7 +402,7 @@ func (c *countDistinctBuffer) Update(ctx *sql.Context, row sql.Row) error { if _, ok := c.exprs[0].(*expression.Star); ok { value = row } else { - val := make([]interface{}, len(c.exprs)) + val := make(sql.Row, len(c.exprs)) for i, expr := range c.exprs { v, err := expr.Eval(ctx, row) if err != nil { @@ -417,12 +417,30 @@ func (c *countDistinctBuffer) Update(ctx *sql.Context, row sql.Row) error { value = val } - hash, err := hashstructure.Hash(value, nil) - if err != nil { - return fmt.Errorf("count distinct unable to hash value: %s", err) + var str string + for _, val := range value.(sql.Row) { + // skip nil values + if val == nil { + return nil + } + v, _, err := types.Text.Convert(val) + if err != nil { + return err + } + vv, ok := v.(string) + if !ok { + return fmt.Errorf("count distinct unable to hash value: %s", err) + } + str += vv + "," } - c.seen[hash] = struct{}{} + hash := xxhash.New() + _, err := hash.WriteString(str) + if err != nil { + return err + } + h := hash.Sum64() + c.seen[h] = struct{}{} return nil } diff --git a/sql/memo/exec_builder.go b/sql/memo/exec_builder.go index 64c88d6002..14857908e1 100644 --- a/sql/memo/exec_builder.go +++ b/sql/memo/exec_builder.go @@ -100,7 +100,7 @@ func (b *ExecBuilder) buildRangeHeap(sr *RangeHeap, children ...sql.Node) (ret s sf := []sql.SortField{{ Column: sortExpr, Order: sql.Ascending, - NullOrdering: sql.NullsLast, // Due to https://github.com/dolthub/go-mysql-server/issues/1903 + NullOrdering: sql.NullsFirst, }} childNode = plan.NewSort(sf, n) } @@ -135,7 +135,7 @@ func (b *ExecBuilder) buildRangeHeapJoin(j *RangeHeapJoin, children ...sql.Node) sf := []sql.SortField{{ Column: sortExpr, Order: sql.Ascending, - NullOrdering: sql.NullsLast, // Due to https://github.com/dolthub/go-mysql-server/issues/1903 + NullOrdering: sql.NullsFirst, }} left = plan.NewSort(sf, children[0]) } diff --git a/sql/rowexec/range_heap_iter.go b/sql/rowexec/range_heap_iter.go index 283702fffc..08b09f81a0 100644 --- a/sql/rowexec/range_heap_iter.go +++ b/sql/rowexec/range_heap_iter.go @@ -215,9 +215,6 @@ func (iter *rangeHeapJoinIter) Close(ctx *sql.Context) (err error) { return err } -type rangeHeapRowIterProvider struct { -} - func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.NodeExecBuilder, primaryRow sql.Row) (err error) { iter.childRowIter, err = builder.Build(ctx, iter.rangeHeapPlan.Child, primaryRow) if err != nil { @@ -235,11 +232,10 @@ func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.Node } func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) { - // Remove rows from the heap if we've advanced beyond their max value. for iter.Len() > 0 { maxValue := iter.Peek() - compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(row[iter.rangeHeapPlan.ValueColumnIndex], maxValue) + compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], maxValue) if err != nil { return nil, err } @@ -258,7 +254,7 @@ func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecB // Advance the child iterator until we encounter a row whose min value is beyond the range. for iter.pendingRow != nil { minValue := iter.pendingRow[iter.rangeHeapPlan.MinColumnIndex] - compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(row[iter.rangeHeapPlan.ValueColumnIndex], minValue) + compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], minValue) if err != nil { return nil, err } @@ -289,13 +285,31 @@ func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecB return sql.RowsToRowIter(iter.activeRanges...), nil } +// When managing the heap, consider all NULLs to come before any non-NULLS. +// This is consistent with the order received if either child node is an index. +// Note: We could get the same behavior by simply excluding values and ranges containing NULL, +// but this is forward compatible if we ever want to convert joins with null-safe conditions into RangeHeapJoins. +func compareNullsFirst(comparisonType sql.Type, a, b interface{}) (int, error) { + if a == nil { + if b == nil { + return 0, nil + } else { + return -1, nil + } + } + if b == nil { + return 1, nil + } + return comparisonType.Compare(a, b) +} + func (iter rangeHeapJoinIter) Len() int { return len(iter.activeRanges) } func (iter *rangeHeapJoinIter) Less(i, j int) bool { lhs := iter.activeRanges[i][iter.rangeHeapPlan.MaxColumnIndex] rhs := iter.activeRanges[j][iter.rangeHeapPlan.MaxColumnIndex] // compareResult will be 0 if lhs==rhs, -1 if lhs < rhs, and +1 if lhs > rhs. - compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(lhs, rhs) + compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, lhs, rhs) if iter.err == nil && err != nil { iter.err = err }