Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
max-hoffman committed Jan 20, 2024
2 parents 7ce96c8 + b80ed6f commit 70145a4
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 37 deletions.
46 changes: 26 additions & 20 deletions enginetest/join_planning_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -1061,12 +1061,12 @@ join uv d on d.u = c.x`,
},
},
{
name: "primary key range join",
name: "indexed range join",
setup: []string{
"create table vals (val int primary key);",
"create table ranges (min int primary key, max int, unique key(min,max));",
"insert into vals values (0), (1), (2), (3), (4), (5), (6);",
"insert into ranges values (0,2), (1,3), (2,4), (3,5), (4,6);",
"create table vals (val int unique key);",
"create table ranges (min int unique key, max int, unique key(min,max));",
"insert into vals values (null), (0), (1), (2), (3), (4), (5), (6);",
"insert into ranges values (null,1), (0,2), (1,3), (2,4), (3,5), (4,6);",
},
tests: []JoinPlanTest{
{
Expand Down Expand Up @@ -1247,8 +1247,9 @@ join uv d on d.u = c.x`,
},
{
q: "select * from vals where exists (select * from vals join ranges on val between min and max where min >= 2 and max <= 5)",
types: []plan.JoinType{plan.JoinTypeCross, plan.JoinTypeInner},
types: []plan.JoinType{plan.JoinTypeCrossHash, plan.JoinTypeInner},
exp: []sql.Row{
{nil},
{0},
{1},
{2},
Expand Down Expand Up @@ -1297,18 +1298,20 @@ join uv d on d.u = c.x`,
{6},
},
},
{
q: "select * from vals where exists (select * from ranges where val between min and max limit 1 offset 1);",
types: []plan.JoinType{plan.JoinTypeSemi},
exp: []sql.Row{
{1},
{2},
{3},
{4},
{5},
{6},
},
},
/*
Disabled because of https://github.com/dolthub/go-mysql-server/issues/2277
{
q: "select * from vals where exists (select * from ranges where val between min and max limit 1 offset 1);",
types: []plan.JoinType{plan.JoinTypeSemi},
exp: []sql.Row{
{1},
{2},
{3},
{4},
{5},
},
},
*/
{
q: "select * from vals where exists (select * from ranges where val between min and max having val > 1);",
types: []plan.JoinType{},
Expand All @@ -1327,8 +1330,8 @@ join uv d on d.u = c.x`,
setup: []string{
"create table vals (val int)",
"create table ranges (min int, max int)",
"insert into vals values (0), (1), (2), (3), (4), (5), (6)",
"insert into ranges values (0,2), (1,3), (2,4), (3,5), (4,6)",
"insert into vals values (null), (0), (1), (2), (3), (4), (5), (6)",
"insert into ranges values (null,1), (0,2), (1,3), (2,4), (3,5), (4,6)",
},
tests: []JoinPlanTest{
{
Expand Down Expand Up @@ -1431,6 +1434,7 @@ join uv d on d.u = c.x`,
q: "select * from vals left join ranges on val > min and val < max",
types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap},
exp: []sql.Row{
{nil, nil, nil},
{0, nil, nil},
{1, 0, 2},
{2, 1, 3},
Expand All @@ -1454,6 +1458,7 @@ join uv d on d.u = c.x`,
q: "select * from vals left join ranges r1 on val > r1.min and val < r1.max left join ranges r2 on r1.min > r2.min and r1.min < r2.max",
types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap, plan.JoinTypeLeftOuterRangeHeap},
exp: []sql.Row{
{nil, nil, nil, nil, nil},
{0, nil, nil, nil, nil},
{1, 0, 2, nil, nil},
{2, 1, 3, 0, 2},
Expand Down Expand Up @@ -1495,6 +1500,7 @@ join uv d on d.u = c.x`,
q: "select * from vals left join (select * from ranges where 0) as newRanges on val > min and val < max;",
types: []plan.JoinType{plan.JoinTypeLeftOuterRangeHeap},
exp: []sql.Row{
{nil, nil, nil},
{0, nil, nil},
{1, nil, nil},
{2, nil, nil},
Expand Down
12 changes: 12 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8892,6 +8892,18 @@ from typestable`,
{"abc"},
},
},
{
Query: "select count(distinct cast(i as decimal)) from mytable;",
Expected: []sql.Row{
{3},
},
},
{
Query: "select count(distinct null);",
Expected: []sql.Row{
{0},
},
},
}

var KeylessQueries = []QueryTest{
Expand Down
2 changes: 1 addition & 1 deletion enginetest/queries/query_plans.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -5167,6 +5167,27 @@ CREATE TABLE tab3 (
},
},
},
{
Name: "count distinct decimals",
SetUpScript: []string{
"create table t (i int, j int)",
"insert into t values (1, 11), (11, 1)",
},
Assertions: []ScriptTestAssertion{
{
Query: "select count(distinct i, j) from t;",
Expected: []sql.Row{
{2},
},
},
{
Query: "select count(distinct cast(i as decimal), cast(j as decimal)) from t;",
Expected: []sql.Row{
{2},
},
},
},
},
}

var SpatialScriptTests = []ScriptTest{
Expand Down
2 changes: 1 addition & 1 deletion sql/expression/function/aggregation/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestCountDistinctEvalStar(t *testing.T) {
require.NoError(b.Update(ctx, sql.NewRow(1)))
require.NoError(b.Update(ctx, sql.NewRow(nil)))
require.NoError(b.Update(ctx, sql.NewRow(1, 2, 3)))
require.Equal(int64(5), evalBuffer(t, b))
require.Equal(int64(4), evalBuffer(t, b))
}

func TestCountDistinctEvalString(t *testing.T) {
Expand Down
30 changes: 24 additions & 6 deletions sql/expression/function/aggregation/unary_agg_buffers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"
"reflect"

"github.com/mitchellh/hashstructure"
"github.com/cespare/xxhash/v2"
"github.com/shopspring/decimal"

"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -402,7 +402,7 @@ func (c *countDistinctBuffer) Update(ctx *sql.Context, row sql.Row) error {
if _, ok := c.exprs[0].(*expression.Star); ok {
value = row
} else {
val := make([]interface{}, len(c.exprs))
val := make(sql.Row, len(c.exprs))
for i, expr := range c.exprs {
v, err := expr.Eval(ctx, row)
if err != nil {
Expand All @@ -417,12 +417,30 @@ func (c *countDistinctBuffer) Update(ctx *sql.Context, row sql.Row) error {
value = val
}

hash, err := hashstructure.Hash(value, nil)
if err != nil {
return fmt.Errorf("count distinct unable to hash value: %s", err)
var str string
for _, val := range value.(sql.Row) {
// skip nil values
if val == nil {
return nil
}
v, _, err := types.Text.Convert(val)
if err != nil {
return err
}
vv, ok := v.(string)
if !ok {
return fmt.Errorf("count distinct unable to hash value: %s", err)
}
str += vv + ","
}

c.seen[hash] = struct{}{}
hash := xxhash.New()
_, err := hash.WriteString(str)
if err != nil {
return err
}
h := hash.Sum64()
c.seen[h] = struct{}{}

return nil
}
Expand Down
4 changes: 2 additions & 2 deletions sql/memo/exec_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (b *ExecBuilder) buildRangeHeap(sr *RangeHeap, children ...sql.Node) (ret s
sf := []sql.SortField{{
Column: sortExpr,
Order: sql.Ascending,
NullOrdering: sql.NullsLast, // Due to https://github.com/dolthub/go-mysql-server/issues/1903
NullOrdering: sql.NullsFirst,
}}
childNode = plan.NewSort(sf, n)
}
Expand Down Expand Up @@ -135,7 +135,7 @@ func (b *ExecBuilder) buildRangeHeapJoin(j *RangeHeapJoin, children ...sql.Node)
sf := []sql.SortField{{
Column: sortExpr,
Order: sql.Ascending,
NullOrdering: sql.NullsLast, // Due to https://github.com/dolthub/go-mysql-server/issues/1903
NullOrdering: sql.NullsFirst,
}}
left = plan.NewSort(sf, children[0])
}
Expand Down
28 changes: 21 additions & 7 deletions sql/rowexec/range_heap_iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,6 @@ func (iter *rangeHeapJoinIter) Close(ctx *sql.Context) (err error) {
return err
}

type rangeHeapRowIterProvider struct {
}

func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.NodeExecBuilder, primaryRow sql.Row) (err error) {
iter.childRowIter, err = builder.Build(ctx, iter.rangeHeapPlan.Child, primaryRow)
if err != nil {
Expand All @@ -235,11 +232,10 @@ func (iter *rangeHeapJoinIter) initializeHeap(ctx *sql.Context, builder sql.Node
}

func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) {

// Remove rows from the heap if we've advanced beyond their max value.
for iter.Len() > 0 {
maxValue := iter.Peek()
compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(row[iter.rangeHeapPlan.ValueColumnIndex], maxValue)
compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], maxValue)
if err != nil {
return nil, err
}
Expand All @@ -258,7 +254,7 @@ func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecB
// Advance the child iterator until we encounter a row whose min value is beyond the range.
for iter.pendingRow != nil {
minValue := iter.pendingRow[iter.rangeHeapPlan.MinColumnIndex]
compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(row[iter.rangeHeapPlan.ValueColumnIndex], minValue)
compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, row[iter.rangeHeapPlan.ValueColumnIndex], minValue)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -289,13 +285,31 @@ func (iter *rangeHeapJoinIter) getActiveRanges(ctx *sql.Context, _ sql.NodeExecB
return sql.RowsToRowIter(iter.activeRanges...), nil
}

// compareNullsFirst compares a and b for heap management, ordering every NULL (nil)
// before any non-NULL value and treating two NULLs as equal. It returns -1, 0, or +1
// for the NULL cases and otherwise defers to comparisonType.Compare.
//
// Sorting NULLs first matches the order rows arrive in when either child node is an index.
// Simply excluding values and ranges that contain NULL would give the same join results,
// but handling them here stays forward compatible with converting joins that use
// null-safe conditions into RangeHeapJoins.
func compareNullsFirst(comparisonType sql.Type, a, b interface{}) (int, error) {
	switch {
	case a == nil && b == nil:
		// Two NULLs are considered equal for ordering purposes.
		return 0, nil
	case a == nil:
		// Only the left side is NULL, so it sorts first.
		return -1, nil
	case b == nil:
		// Only the right side is NULL, so the left side sorts after it.
		return 1, nil
	default:
		return comparisonType.Compare(a, b)
	}
}

// Len reports how many rows are currently held in iter.activeRanges.
// NOTE(review): this uses a value receiver while the neighbouring Less method uses a
// pointer receiver; presumably both serve the same heap/sort interface — confirm
// before changing either.
func (iter rangeHeapJoinIter) Len() int { return len(iter.activeRanges) }

func (iter *rangeHeapJoinIter) Less(i, j int) bool {
lhs := iter.activeRanges[i][iter.rangeHeapPlan.MaxColumnIndex]
rhs := iter.activeRanges[j][iter.rangeHeapPlan.MaxColumnIndex]
// compareResult will be 0 if lhs==rhs, -1 if lhs < rhs, and +1 if lhs > rhs.
compareResult, err := iter.rangeHeapPlan.ComparisonType.Compare(lhs, rhs)
compareResult, err := compareNullsFirst(iter.rangeHeapPlan.ComparisonType, lhs, rhs)
if iter.err == nil && err != nil {
iter.err = err
}
Expand Down

0 comments on commit 70145a4

Please sign in to comment.