Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update GMS to detect INSERT statements with row alias and return error. #2463

Merged
merged 7 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ func TestReadOnlyDatabases(t *testing.T, harness ReadOnlyDatabaseHarness) {
} {
for _, tt := range querySet {
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
AssertErrWithBindings(t, engine, harness, tt.WriteQuery, tt.Bindings, analyzererrors.ErrReadOnlyDatabase)
})
}
Expand Down Expand Up @@ -1235,10 +1239,16 @@ func TestDelete(t *testing.T, harness Harness) {
for name, coster := range biasedCosters {
t.Run(name+" join", func(t *testing.T) {
for _, tt := range queries.DeleteJoinTests {
e := mustNewEngine(t, harness)
e.EngineAnalyzer().Coster = coster
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
e := mustNewEngine(t, harness)
e.EngineAnalyzer().Coster = coster
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
})
}
})
}
Expand Down
48 changes: 28 additions & 20 deletions enginetest/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -1046,37 +1046,45 @@ func ExtractQueryNode(node sql.Node) sql.Node {

// RunWriteQueryTest runs the specified |tt| WriteQueryTest using the specified harness.
func RunWriteQueryTest(t *testing.T, harness Harness, tt queries.WriteQueryTest) {
e := mustNewEngine(t, harness)
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
e := mustNewEngine(t, harness)
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
})
}

// RunWriteQueryTestWithEngine runs the specified |tt| WriteQueryTest, using the specified harness and engine. Callers
// are still responsible for closing the engine.
func RunWriteQueryTestWithEngine(t *testing.T, harness Harness, e QueryEngine, tt queries.WriteQueryTest) {
t.Run(tt.WriteQuery, func(t *testing.T) {
if sh, ok := harness.(SkippingHarness); ok {
if sh.SkipQueryTest(tt.WriteQuery) {
t.Logf("Skipping query %s", tt.WriteQuery)
return
}
if sh.SkipQueryTest(tt.SelectQuery) {
t.Logf("Skipping query %s", tt.SelectQuery)
return
}
if sh, ok := harness.(SkippingHarness); ok {
if sh.SkipQueryTest(tt.WriteQuery) {
t.Logf("Skipping query %s", tt.WriteQuery)
return
}
ctx := NewContext(harness)
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil)
expectedSelect := tt.ExpectedSelect
if IsServerEngine(e) && tt.SkipServerEngine {
expectedSelect = nil
if sh.SkipQueryTest(tt.SelectQuery) {
t.Logf("Skipping query %s", tt.SelectQuery)
return
}
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil)
})
}
ctx := NewContext(harness)
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil)
expectedSelect := tt.ExpectedSelect
if IsServerEngine(e) && tt.SkipServerEngine {
expectedSelect = nil
}
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil)
}

func runWriteQueryTestPrepared(t *testing.T, harness Harness, tt queries.WriteQueryTest) {
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
if sh, ok := harness.(SkippingHarness); ok {
if sh.SkipQueryTest(tt.WriteQuery) {
t.Logf("Skipping query %s", tt.WriteQuery)
Expand Down
26 changes: 26 additions & 0 deletions enginetest/queries/insert_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,20 @@ var InsertQueries = []WriteQueryTest{
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
ExpectedSelect: []sql.Row{{int64(1), "hi"}},
},
{
WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt(new_i,new_s) ON DUPLICATE KEY UPDATE s=new_s",
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
ExpectedSelect: []sql.Row{{int64(1), "hi"}},
Skip: true, // https://github.com/dolthub/dolt/issues/7638
},
{
WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt ON DUPLICATE KEY UPDATE mytable.s=dt.s",
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
ExpectedSelect: []sql.Row{{int64(1), "hir"}},
Skip: true, // https://github.com/dolthub/dolt/issues/7638
},
{
WriteQuery: "INSERT INTO mytable (s,i) values ('dup',1) ON DUPLICATE KEY UPDATE s=CONCAT(VALUES(s), 'licate')",
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
Expand Down Expand Up @@ -1881,6 +1895,18 @@ var InsertScripts = []ScriptTest{
},
},
},
{
Name: "insert on duplicate key with incorrect row alias",
SetUpScript: []string{
`create table a (i int primary key)`,
},
Assertions: []ScriptTestAssertion{
{
Query: `insert into a values (1) as new(c, d) on duplicate key update i = c`,
ExpectedErr: sql.ErrColumnCountMismatch,
},
},
},
{
Name: "Insert throws primary key violations",
SetUpScript: []string{
Expand Down
1 change: 1 addition & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -10763,6 +10763,7 @@ type WriteQueryTest struct {
SelectQuery string
ExpectedSelect []sql.Row
Bindings map[string]*query.BindVariable
Skip bool
SkipServerEngine bool
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
github.com/gocraft/dbr/v2 v2.7.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df h1:hXB89Qhyu0ymVhP4AvuCtHWGpQmZN0Tt5Cc58Ig8/dg=
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80 h1:BG7DheiFrbvKYtPmZ1avXA/VPKzz+Bv7L0ytUi83kyQ=
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
Expand Down
29 changes: 22 additions & 7 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package planbuilder

import (
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -128,7 +129,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName
switch v := ir.(type) {
case ast.SelectStatement:
return b.buildSelectStmt(inScope, v)
case ast.Values:
case *ast.AliasedValues:
outScope = b.buildInsertValues(inScope, v, columnNames, tableName, destSchema)
default:
err := sql.ErrUnsupportedSyntax.New(ast.String(ir))
Expand All @@ -137,7 +138,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName
return
}

func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
func (b *Builder) buildInsertValues(inScope *scope, v *ast.AliasedValues, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
columnDefaultValues := make([]*sql.ColumnDefaultValue, len(columnNames))

for i, columnName := range columnNames {
Expand All @@ -156,8 +157,22 @@ func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []
}
}

exprTuples := make([][]sql.Expression, len(v))
for i, vt := range v {
if !v.As.IsEmpty() {
if len(v.Columns) != 0 {
for _, tuple := range v.Values {
if len(v.Columns) != len(tuple) {
err := sql.ErrColumnCountMismatch.New()
b.handleErr(err)
}
}

err := errors.New("insert row aliases are not currently supported; use the VALUES() function instead")
b.handleErr(err)
}
}

exprTuples := make([][]sql.Expression, len(v.Values))
for i, vt := range v.Values {
// noExprs is an edge case where we fill VALUES with nil expressions
noExprs := len(vt) == 0
// triggerUnknownTable is an edge case where we ignored an unresolved
Expand Down Expand Up @@ -217,10 +232,10 @@ func reorderSchema(names []string, schema sql.Schema) sql.Schema {
return newSch
}

func (b *Builder) buildValues(inScope *scope, v ast.Values) (outScope *scope) {
func (b *Builder) buildValues(inScope *scope, v ast.AliasedValues) (outScope *scope) {
// TODO add literals to outScope?
exprTuples := make([][]sql.Expression, len(v))
for i, vt := range v {
exprTuples := make([][]sql.Expression, len(v.Values))
for i, vt := range v.Values {
exprs := make([]sql.Expression, len(vt))
exprTuples[i] = exprs
for j, e := range vt {
Expand Down