Skip to content

Commit

Permalink
Merge joins populate join stats (#2247)
Browse files Browse the repository at this point in the history
* Statistics for merge joins

* Join Stats

Joins with a merge join option and table statistics for
both merge join indexes will use the histogram merging
logic for estimating join cardinality.

* more costing tweaks

* costing fixes

* tighten distributions, comments

* nick comments
  • Loading branch information
max-hoffman authored Jan 20, 2024
1 parent b80ed6f commit f795b05
Show file tree
Hide file tree
Showing 39 changed files with 36,908 additions and 43,315 deletions.
32 changes: 1 addition & 31 deletions enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ func TestTableFunctions(t *testing.T) {
harness.Setup(setup.MydbData)

databaseProvider := harness.NewDatabaseProvider()
testDatabaseProvider := NewTestProvider(
testDatabaseProvider := enginetest.NewTestProvider(
&databaseProvider,
SimpleTableFunction{},
memory.IntSequenceTable{},
Expand Down Expand Up @@ -943,36 +943,6 @@ func (itr *SimpleTableFunctionRowIter) Close(_ *sql.Context) error {
return nil
}

var _ sql.FunctionProvider = (*TestProvider)(nil)

type TestProvider struct {
sql.MutableDatabaseProvider
tableFunctions map[string]sql.TableFunction
}

func NewTestProvider(dbProvider *sql.MutableDatabaseProvider, tf ...sql.TableFunction) *TestProvider {
tfs := make(map[string]sql.TableFunction)
for _, tf := range tf {
tfs[strings.ToLower(tf.Name())] = tf
}
return &TestProvider{
*dbProvider,
tfs,
}
}

func (t TestProvider) Function(_ *sql.Context, name string) (sql.Function, error) {
return nil, sql.ErrFunctionNotFound.New(name)
}

func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) {
if tf, ok := t.tableFunctions[strings.ToLower(name)]; ok {
return tf, nil
}

return nil, sql.ErrTableFunctionNotFound.New(name)
}

func TestTimestampBindingsCanBeConverted(t *testing.T) {
db, close := newDatabase()
defer close()
Expand Down
3 changes: 0 additions & 3 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -5258,9 +5258,6 @@ func TestIndexes(t *testing.T, h Harness) {
}

func TestIndexPrefix(t *testing.T, h Harness) {
e := mustNewEngine(t, h)
defer e.Close()

for _, tt := range queries.IndexPrefixQueries {
TestScript(t, h, tt)
}
Expand Down
5 changes: 3 additions & 2 deletions enginetest/histogram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func runStatsSuite(t *testing.T, tests []statsTest, rowCnt, bucketCnt int, debug
rStat.Hist = append(rStat.Hist, b.(*stats.Bucket))
}

res, err := stats.Join(stats.UpdateCounts(lStat), stats.UpdateCounts(rStat), []int{0}, []int{0}, debug)
res, err := stats.Join(stats.UpdateCounts(lStat), stats.UpdateCounts(rStat), 1, debug)
require.NoError(t, err)
if debug {
log.Printf("join %s\n", res.Histogram().DebugString())
Expand All @@ -304,6 +304,7 @@ func runStatsSuite(t *testing.T, tests []statsTest, rowCnt, bucketCnt int, debug
if debug {
log.Println(res.RowCount(), exp, delta)
}

// This compares the error percentage for our estimate to an
// error threshold specified in the statTest. The error bounds
// are loose and mostly useful for debugging at this point.
Expand Down Expand Up @@ -508,7 +509,7 @@ func normalDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int, mean,
break
}
row := sql.Row{int64(i)}
for _, v := range val {
for _, v := range val[1:] {
row = append(row, int64(v.(float64)))
}
err = tab.Insert(ctx, row)
Expand Down
6 changes: 5 additions & 1 deletion enginetest/join_op_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup"
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -57,7 +59,9 @@ func TestJoinOps(t *testing.T, harness Harness, tests []joinOpTest) {
}

if pro, ok := e.EngineAnalyzer().Catalog.DbProvider.(*memory.DbProvider); ok {
e.EngineAnalyzer().Catalog.DbProvider = pro.WithTableFunctions([]sql.TableFunction{memory.RequiredLookupTable{}})
newProv, err := pro.WithTableFunctions(memory.RequiredLookupTable{})
require.NoError(t, err)
e.EngineAnalyzer().Catalog.DbProvider = newProv.(sql.DatabaseProvider)
}

for k, c := range biasedCosters {
Expand Down
21 changes: 11 additions & 10 deletions enginetest/join_planning_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ var JoinPlanningTests = []struct {
},
{
q: "select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y = s-1 order by 1, 3",
types: []plan.JoinType{plan.JoinTypeHash},
types: []plan.JoinType{plan.JoinTypeLookup},
exp: []sql.Row{{4, 4, 3, 3}, {5, 4, 3, 3}},
},
//{
Expand Down Expand Up @@ -185,6 +185,7 @@ var JoinPlanningTests = []struct {
},
},
{
// todo: rewrite implementing new stats interface
name: "merge join large and small table",
setup: []string{
"CREATE table xy (x int primary key, y int, index y_idx(y));",
Expand Down Expand Up @@ -403,13 +404,13 @@ order by 1;`,
},
{
q: "select * from xy where y-1 in (select u from uv) order by 1;",
types: []plan.JoinType{plan.JoinTypeHash},
types: []plan.JoinType{plan.JoinTypeSemiLookup},
exp: []sql.Row{{0, 2}, {2, 1}, {3, 3}},
},
{
// semi join will be right-side, be passed non-nil parent row
q: "select x,a from ab, (select * from xy where x = (select r from rs where r = 1) order by 1) sq order by 1,2",
types: []plan.JoinType{plan.JoinTypeCrossHash, plan.JoinTypeLookup},
types: []plan.JoinType{plan.JoinTypeCrossHash, plan.JoinTypeMerge},
exp: []sql.Row{{1, 0}, {1, 1}, {1, 2}, {1, 3}},
},
//{
Expand All @@ -435,7 +436,7 @@ order by 1;`,
},
{
q: "select * from xy where y-1 in (select u from uv order by 1) order by 1;",
types: []plan.JoinType{plan.JoinTypeHash},
types: []plan.JoinType{plan.JoinTypeSemiLookup},
exp: []sql.Row{{0, 2}, {2, 1}, {3, 3}},
},
{
Expand All @@ -445,7 +446,7 @@ order by 1;`,
},
{
q: "select * from xy where x in (select u from uv join ab on u = a and a = 2) order by 1;",
types: []plan.JoinType{plan.JoinTypeLookup, plan.JoinTypeMerge},
types: []plan.JoinType{plan.JoinTypeHash, plan.JoinTypeMerge},
exp: []sql.Row{{2, 1}},
},
{
Expand Down Expand Up @@ -533,7 +534,7 @@ HAVING count(v) >= 1)`,
},
{
q: "select * from xy where x in (select cnt from (select count(u) as cnt from uv group by v having cnt > 0) sq) order by 1,2;",
types: []plan.JoinType{plan.JoinTypeHash},
types: []plan.JoinType{plan.JoinTypeLookup},
exp: []sql.Row{{2, 1}},
},
{
Expand Down Expand Up @@ -575,7 +576,7 @@ WHERE EXISTS (
select x from xy where
not exists (select a from ab where a = x and a = 1) and
not exists (select a from ab where a = x and a = 2)`,
types: []plan.JoinType{plan.JoinTypeLeftOuterLookup, plan.JoinTypeLeftOuterMerge},
types: []plan.JoinType{plan.JoinTypeLeftOuterHashExcludeNulls, plan.JoinTypeLeftOuterMerge},
exp: []sql.Row{{0}, {3}},
},
{
Expand All @@ -586,7 +587,7 @@ select * from xy where x in (
)
SELECT u FROM uv, tree where u = s
)`,
types: []plan.JoinType{plan.JoinTypeHash, plan.JoinTypeHash},
types: []plan.JoinType{plan.JoinTypeLookup, plan.JoinTypeHash},
exp: []sql.Row{{1, 0}},
},
{
Expand Down Expand Up @@ -750,7 +751,7 @@ where u in (select * from rec);`,
tests: []JoinPlanTest{
{
q: "select * from xy where x in (select u from uv join ab on u = a and a = 2) order by 1;",
types: []plan.JoinType{plan.JoinTypeHash, plan.JoinTypeMerge},
types: []plan.JoinType{plan.JoinTypeLookup, plan.JoinTypeMerge},
exp: []sql.Row{{2, 1}},
},
{
Expand Down Expand Up @@ -1007,7 +1008,7 @@ join uv d on d.u = c.x`,
},
{
q: "select /*+ LOOKUP_JOIN(s,uv) */ 1 from xy s where x in (select u from uv)",
types: []plan.JoinType{plan.JoinTypeLookup},
types: []plan.JoinType{plan.JoinTypeSemiLookup},
},
{
q: "select /*+ SEMI_JOIN(s,uv) */ 1 from xy s where x in (select u from uv)",
Expand Down
171 changes: 171 additions & 0 deletions enginetest/join_stats_tests.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package enginetest

import (
"strings"
"testing"

"github.com/stretchr/testify/require"

"github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup"
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/sql"
)

func TestJoinStats(t *testing.T, harness Harness) {
harness.Setup(setup.MydbData)

for _, tt := range JoinStatTests {
t.Run(tt.name, func(t *testing.T) {
harness.Setup([]setup.SetupScript{setup.MydbData[0]})
e := mustNewEngine(t, harness)
defer e.Close()

tfp, ok := e.EngineAnalyzer().Catalog.DbProvider.(sql.TableFunctionProvider)
if !ok {
return
}
newPro, err := tfp.WithTableFunctions(memory.ExponentialDistTable{}, memory.NormalDistTable{})
require.NoError(t, err)
e.EngineAnalyzer().Catalog.DbProvider = newPro.(sql.DatabaseProvider)

ctx := harness.NewContext()
for _, q := range tt.setup {
_, iter, err := e.Query(ctx, q)
require.NoError(t, err)
_, err = sql.RowIterToRows(ctx, iter)
require.NoError(t, err)
}

for _, tt := range tt.tests {
if tt.order != nil {
evalJoinOrder(t, harness, e, tt.q, tt.order, tt.skipOld)
}
if tt.exp != nil {
evalJoinCorrectness(t, harness, e, tt.q, tt.q, tt.exp, false)
}
}
})
}
}

var JoinStatTests = []struct {
name string
setup []string
tests []JoinPlanTest
}{
{
name: "test table orders with normal distributions",
setup: []string{
"create table u0 (a int primary key, b int, c int, key (b,c))",
"insert into u0 select * from normal_dist(2, 500, 0, 5)",
"create table u0_2 (a int primary key, b int, c int, key (b,c))",
"insert into u0_2 select * from normal_dist(2, 2000, 0, 5)",
"create table `u-15` (a int primary key, b int, c int, key (b,c))",
"insert into `u-15` select * from normal_dist(2, 3000, -15, 5)",
"create table `u+15` (a int primary key, b int, c int, key (b,c))",
"insert into `u+15` select * from normal_dist(2, 4000, 15, 5)",
"analyze table u0",
"analyze table u0_2",
"analyze table `u-15`",
"analyze table `u+15`",
},
tests: []JoinPlanTest{
{
// a is smaller
q: "select /*+ LEFT_DEEP */ count(*) from `u-15` a join `u+15` b on a.b = b.b",
order: []string{"a", "b"},
},
{
// b with filter is smaller
q: "select /*+ LEFT_DEEP */ count(*) from `u-15` a join `u+15` b on a.b = b.b where b.b < 15",
order: []string{"b", "a"},
},
{
// a < c < b, axc is smallest join
q: "select /*+ LEFT_DEEP */ count(*) from `u-15` a join u0_2 b on a.b = b.b join `u+15` c on a.b = c.b where a.b > -15 and c.b < 15",
order: []string{"a", "c", "b"},
},
},
},
{
// there is a trade-off for these where we either pick the first table
// first if card(b) < card(axc), or we choose (axc) if its intermediate
// result cardinality is smaller than filtered (b).
name: "test table orders with filters and normal distributions",
setup: []string{
"create table u0 (a int primary key, b int, c int, key (b,c))",
"insert into u0 select * from normal_dist(2, 2000, 0, 5)",
"create table u0_2 (a int primary key, b int, c int, key (b,c))",
"insert into u0_2 select * from normal_dist(2, 2000, 0, 5)",
"create table `u-15` (a int primary key, b int, c int, key (b,c))",
"insert into `u-15` select * from normal_dist(2, 2000, -10, 5)",
"create table `u+15` (a int primary key, b int, c int, key (b,c))",
"insert into `u+15` select * from normal_dist(2, 2000, 10, 5)",
"analyze table u0",
"analyze table u0_2",
"analyze table `u-15`",
"analyze table `u+15`",
},
tests: []JoinPlanTest{
{
// axc is smallest join, a is smallest table
q: "select /*+ LEFT_DEEP */ count(*) from u0 b join `u-15` a on a.b = b.b join `u+15` c on a.b = c.b where a.b > 2",
order: []string{"a", "c", "b"},
},
{
// b is smallest table, bxc is smallest b-connected join
// due to b < 0 filter and positive c skew
q: "select /*+ LEFT_DEEP */ count(*) from u0 b join `u-15` a on a.b = b.b join `u+15` c on a.b = c.b where b.b < -2",
order: []string{"b", "c", "a"},
},
{
q: "select /*+ LEFT_DEEP */ count(*) from u0 b join `u-15` a on a.b = b.b join `u+15` c on a.b = c.b where b.b < -2",
order: []string{"b", "c", "a"},
},
{
// b is smallest table, bxa is smallest b-connected join
// due to b > 0 filter and negative c skew
q: "select /*+ LEFT_DEEP */ count(*) from `u-15` a join u0 b on a.b = b.b join `u+15` c on a.b = c.b where b.b > 2",
order: []string{"b", "a", "c"},
},
{
q: "select /*+ LEFT_DEEP */ count(*) from u0 b join `u-15` a on a.b = b.b join `u+15` c on a.b = c.b where b.b > 2",
order: []string{"b", "a", "c"},
},
},
},
}

func NewTestProvider(dbProvider *sql.MutableDatabaseProvider, tf ...sql.TableFunction) *TestProvider {
tfs := make(map[string]sql.TableFunction)
for _, tf := range tf {
tfs[strings.ToLower(tf.Name())] = tf
}
return &TestProvider{
*dbProvider,
tfs,
}
}

var _ sql.FunctionProvider = (*TestProvider)(nil)

type TestProvider struct {
sql.MutableDatabaseProvider
tableFunctions map[string]sql.TableFunction
}

func (t TestProvider) Function(_ *sql.Context, name string) (sql.Function, error) {
return nil, sql.ErrFunctionNotFound.New(name)
}

func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) {
if tf, ok := t.tableFunctions[strings.ToLower(name)]; ok {
return tf, nil
}

return nil, sql.ErrTableFunctionNotFound.New(name)
}

func (t TestProvider) WithTableFunctions(fns ...sql.TableFunction) (sql.TableFunctionProvider, error) {
return t, nil
}
4 changes: 4 additions & 0 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ func TestJoinOps(t *testing.T) {
enginetest.TestJoinOps(t, enginetest.NewDefaultMemoryHarness(), enginetest.DefaultJoinOpTests)
}

func TestJoinStats(t *testing.T) {
enginetest.TestJoinStats(t, enginetest.NewDefaultMemoryHarness())
}

// TestJSONTableQueries runs the canonical test queries against a single threaded index enabled harness.
func TestJSONTableQueries(t *testing.T) {
enginetest.TestJSONTableQueries(t, enginetest.NewDefaultMemoryHarness())
Expand Down
Loading

0 comments on commit f795b05

Please sign in to comment.