Skip to content

Commit

Permalink
Merge pull request #2463 from dolthub/nicktobey/insert-as
Browse files Browse the repository at this point in the history
Update GMS to detect INSERT statements with row alias and return error.
  • Loading branch information
nicktobey authored Apr 16, 2024
2 parents 436ceb5 + f43f8e3 commit cf70da4
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 34 deletions.
18 changes: 14 additions & 4 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ func TestReadOnlyDatabases(t *testing.T, harness ReadOnlyDatabaseHarness) {
} {
for _, tt := range querySet {
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
AssertErrWithBindings(t, engine, harness, tt.WriteQuery, tt.Bindings, analyzererrors.ErrReadOnlyDatabase)
})
}
Expand Down Expand Up @@ -1235,10 +1239,16 @@ func TestDelete(t *testing.T, harness Harness) {
for name, coster := range biasedCosters {
t.Run(name+" join", func(t *testing.T) {
for _, tt := range queries.DeleteJoinTests {
e := mustNewEngine(t, harness)
e.EngineAnalyzer().Coster = coster
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
e := mustNewEngine(t, harness)
e.EngineAnalyzer().Coster = coster
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
})
}
})
}
Expand Down
48 changes: 28 additions & 20 deletions enginetest/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -1046,37 +1046,45 @@ func ExtractQueryNode(node sql.Node) sql.Node {

// RunWriteQueryTest runs the specified |tt| WriteQueryTest using the specified harness.
func RunWriteQueryTest(t *testing.T, harness Harness, tt queries.WriteQueryTest) {
e := mustNewEngine(t, harness)
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
e := mustNewEngine(t, harness)
defer e.Close()
RunWriteQueryTestWithEngine(t, harness, e, tt)
})
}

// RunWriteQueryTestWithEngine runs the specified |tt| WriteQueryTest, using the specified harness and engine. Callers
// are still responsible for closing the engine.
func RunWriteQueryTestWithEngine(t *testing.T, harness Harness, e QueryEngine, tt queries.WriteQueryTest) {
t.Run(tt.WriteQuery, func(t *testing.T) {
if sh, ok := harness.(SkippingHarness); ok {
if sh.SkipQueryTest(tt.WriteQuery) {
t.Logf("Skipping query %s", tt.WriteQuery)
return
}
if sh.SkipQueryTest(tt.SelectQuery) {
t.Logf("Skipping query %s", tt.SelectQuery)
return
}
if sh, ok := harness.(SkippingHarness); ok {
if sh.SkipQueryTest(tt.WriteQuery) {
t.Logf("Skipping query %s", tt.WriteQuery)
return
}
ctx := NewContext(harness)
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil)
expectedSelect := tt.ExpectedSelect
if IsServerEngine(e) && tt.SkipServerEngine {
expectedSelect = nil
if sh.SkipQueryTest(tt.SelectQuery) {
t.Logf("Skipping query %s", tt.SelectQuery)
return
}
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil)
})
}
ctx := NewContext(harness)
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil)
expectedSelect := tt.ExpectedSelect
if IsServerEngine(e) && tt.SkipServerEngine {
expectedSelect = nil
}
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil)
}

func runWriteQueryTestPrepared(t *testing.T, harness Harness, tt queries.WriteQueryTest) {
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
t.Skip()
return
}
if sh, ok := harness.(SkippingHarness); ok {
if sh.SkipQueryTest(tt.WriteQuery) {
t.Logf("Skipping query %s", tt.WriteQuery)
Expand Down
26 changes: 26 additions & 0 deletions enginetest/queries/insert_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,20 @@ var InsertQueries = []WriteQueryTest{
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
ExpectedSelect: []sql.Row{{int64(1), "hi"}},
},
{
WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt(new_i,new_s) ON DUPLICATE KEY UPDATE s=new_s",
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
ExpectedSelect: []sql.Row{{int64(1), "hi"}},
Skip: true, // https://github.com/dolthub/dolt/issues/7638
},
{
WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt ON DUPLICATE KEY UPDATE mytable.s=dt.s",
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
ExpectedSelect: []sql.Row{{int64(1), "hir"}},
Skip: true, // https://github.com/dolthub/dolt/issues/7638
},
{
WriteQuery: "INSERT INTO mytable (s,i) values ('dup',1) ON DUPLICATE KEY UPDATE s=CONCAT(VALUES(s), 'licate')",
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
Expand Down Expand Up @@ -1881,6 +1895,18 @@ var InsertScripts = []ScriptTest{
},
},
},
{
Name: "insert on duplicate key with incorrect row alias",
SetUpScript: []string{
`create table a (i int primary key)`,
},
Assertions: []ScriptTestAssertion{
{
Query: `insert into a values (1) as new(c, d) on duplicate key update i = c`,
ExpectedErr: sql.ErrColumnCountMismatch,
},
},
},
{
Name: "Insert throws primary key violations",
SetUpScript: []string{
Expand Down
1 change: 1 addition & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -10763,6 +10763,7 @@ type WriteQueryTest struct {
SelectQuery string
ExpectedSelect []sql.Row
Bindings map[string]*query.BindVariable
Skip bool
SkipServerEngine bool
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
github.com/gocraft/dbr/v2 v2.7.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df h1:hXB89Qhyu0ymVhP4AvuCtHWGpQmZN0Tt5Cc58Ig8/dg=
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80 h1:BG7DheiFrbvKYtPmZ1avXA/VPKzz+Bv7L0ytUi83kyQ=
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
Expand Down
29 changes: 22 additions & 7 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package planbuilder

import (
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -128,7 +129,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName
switch v := ir.(type) {
case ast.SelectStatement:
return b.buildSelectStmt(inScope, v)
case ast.Values:
case *ast.AliasedValues:
outScope = b.buildInsertValues(inScope, v, columnNames, tableName, destSchema)
default:
err := sql.ErrUnsupportedSyntax.New(ast.String(ir))
Expand All @@ -137,7 +138,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName
return
}

func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
func (b *Builder) buildInsertValues(inScope *scope, v *ast.AliasedValues, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
columnDefaultValues := make([]*sql.ColumnDefaultValue, len(columnNames))

for i, columnName := range columnNames {
Expand All @@ -156,8 +157,22 @@ func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []
}
}

exprTuples := make([][]sql.Expression, len(v))
for i, vt := range v {
if !v.As.IsEmpty() {
if len(v.Columns) != 0 {
for _, tuple := range v.Values {
if len(v.Columns) != len(tuple) {
err := sql.ErrColumnCountMismatch.New()
b.handleErr(err)
}
}

err := errors.New("insert row aliases are not currently supported; use the VALUES() function instead")
b.handleErr(err)
}
}

exprTuples := make([][]sql.Expression, len(v.Values))
for i, vt := range v.Values {
// noExprs is an edge case where we fill VALUES with nil expressions
noExprs := len(vt) == 0
// triggerUnknownTable is an edge case where we ignored an unresolved
Expand Down Expand Up @@ -217,10 +232,10 @@ func reorderSchema(names []string, schema sql.Schema) sql.Schema {
return newSch
}

func (b *Builder) buildValues(inScope *scope, v ast.Values) (outScope *scope) {
func (b *Builder) buildValues(inScope *scope, v ast.AliasedValues) (outScope *scope) {
// TODO add literals to outScope?
exprTuples := make([][]sql.Expression, len(v))
for i, vt := range v {
exprTuples := make([][]sql.Expression, len(v.Values))
for i, vt := range v.Values {
exprs := make([]sql.Expression, len(vt))
exprTuples[i] = exprs
for j, e := range vt {
Expand Down

0 comments on commit cf70da4

Please sign in to comment.