diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 45a72f7d96..bb4a752a7b 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -367,6 +367,10 @@ func TestReadOnlyDatabases(t *testing.T, harness ReadOnlyDatabaseHarness) { } { for _, tt := range querySet { t.Run(tt.WriteQuery, func(t *testing.T) { + if tt.Skip { + t.Skip() + return + } AssertErrWithBindings(t, engine, harness, tt.WriteQuery, tt.Bindings, analyzererrors.ErrReadOnlyDatabase) }) } @@ -1235,10 +1239,16 @@ func TestDelete(t *testing.T, harness Harness) { for name, coster := range biasedCosters { t.Run(name+" join", func(t *testing.T) { for _, tt := range queries.DeleteJoinTests { - e := mustNewEngine(t, harness) - e.EngineAnalyzer().Coster = coster - defer e.Close() - RunWriteQueryTestWithEngine(t, harness, e, tt) + t.Run(tt.WriteQuery, func(t *testing.T) { + if tt.Skip { + t.Skip() + return + } + e := mustNewEngine(t, harness) + e.EngineAnalyzer().Coster = coster + defer e.Close() + RunWriteQueryTestWithEngine(t, harness, e, tt) + }) } }) } diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index eb9f987c24..62be52745d 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -1046,37 +1046,45 @@ func ExtractQueryNode(node sql.Node) sql.Node { // RunWriteQueryTest runs the specified |tt| WriteQueryTest using the specified harness. func RunWriteQueryTest(t *testing.T, harness Harness, tt queries.WriteQueryTest) { - e := mustNewEngine(t, harness) - defer e.Close() - RunWriteQueryTestWithEngine(t, harness, e, tt) + t.Run(tt.WriteQuery, func(t *testing.T) { + if tt.Skip { + t.Skip() + return + } + e := mustNewEngine(t, harness) + defer e.Close() + RunWriteQueryTestWithEngine(t, harness, e, tt) + }) } // RunWriteQueryTestWithEngine runs the specified |tt| WriteQueryTest, using the specified harness and engine. Callers // are still responsible for closing the engine. func RunWriteQueryTestWithEngine(t *testing.T, harness Harness, e QueryEngine, tt queries.WriteQueryTest) { - t.Run(tt.WriteQuery, func(t *testing.T) { - if sh, ok := harness.(SkippingHarness); ok { - if sh.SkipQueryTest(tt.WriteQuery) { - t.Logf("Skipping query %s", tt.WriteQuery) - return - } - if sh.SkipQueryTest(tt.SelectQuery) { - t.Logf("Skipping query %s", tt.SelectQuery) - return - } + if sh, ok := harness.(SkippingHarness); ok { + if sh.SkipQueryTest(tt.WriteQuery) { + t.Logf("Skipping query %s", tt.WriteQuery) + return } - ctx := NewContext(harness) - TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil) - expectedSelect := tt.ExpectedSelect - if IsServerEngine(e) && tt.SkipServerEngine { - expectedSelect = nil + if sh.SkipQueryTest(tt.SelectQuery) { + t.Logf("Skipping query %s", tt.SelectQuery) + return } - TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil) - }) + } + ctx := NewContext(harness) + TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil) + expectedSelect := tt.ExpectedSelect + if IsServerEngine(e) && tt.SkipServerEngine { + expectedSelect = nil + } + TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil) } func runWriteQueryTestPrepared(t *testing.T, harness Harness, tt queries.WriteQueryTest) { t.Run(tt.WriteQuery, func(t *testing.T) { + if tt.Skip { + t.Skip() + return + } if sh, ok := harness.(SkippingHarness); ok { if sh.SkipQueryTest(tt.WriteQuery) { t.Logf("Skipping query %s", tt.WriteQuery) diff --git a/enginetest/queries/insert_queries.go b/enginetest/queries/insert_queries.go index 4c807f039c..7e40a13c54 100644 --- a/enginetest/queries/insert_queries.go +++ b/enginetest/queries/insert_queries.go @@ -476,6 +476,20 @@ var InsertQueries = []WriteQueryTest{ SelectQuery: "SELECT * FROM mytable WHERE i = 1", ExpectedSelect: []sql.Row{{int64(1), "hi"}}, }, + { + WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt(new_i,new_s) ON DUPLICATE KEY UPDATE s=new_s", + ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}}, + SelectQuery: "SELECT * FROM mytable WHERE i = 1", + ExpectedSelect: []sql.Row{{int64(1), "hi"}}, + Skip: true, // https://github.com/dolthub/dolt/issues/7638 + }, + { + WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt ON DUPLICATE KEY UPDATE mytable.s=dt.s", + ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}}, + SelectQuery: "SELECT * FROM mytable WHERE i = 1", + ExpectedSelect: []sql.Row{{int64(1), "hir"}}, + Skip: true, // https://github.com/dolthub/dolt/issues/7638 + }, { WriteQuery: "INSERT INTO mytable (s,i) values ('dup',1) ON DUPLICATE KEY UPDATE s=CONCAT(VALUES(s), 'licate')", ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}}, @@ -1881,6 +1895,18 @@ var InsertScripts = []ScriptTest{ }, }, }, + { + Name: "insert on duplicate key with incorrect row alias", + SetUpScript: []string{ + `create table a (i int primary key)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `insert into a values (1) as new(c, d) on duplicate key update i = c`, + ExpectedErr: sql.ErrColumnCountMismatch, + }, + }, + }, { Name: "Insert throws primary key violations", SetUpScript: []string{ diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 7ed22a0bad..fc06fcb9ef 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -10763,6 +10763,7 @@ type WriteQueryTest struct { SelectQuery string ExpectedSelect []sql.Row Bindings map[string]*query.BindVariable + Skip bool SkipServerEngine bool } diff --git a/go.mod b/go.mod index 3c7aa1fcb6..2d85535d62 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df + github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 1972d94588..e26cf9664d 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df h1:hXB89Qhyu0ymVhP4AvuCtHWGpQmZN0Tt5Cc58Ig8/dg= -github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80 h1:BG7DheiFrbvKYtPmZ1avXA/VPKzz+Bv7L0ytUi83kyQ= +github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index 0a9990dfb6..4e7f9ff599 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -15,6 +15,7 @@ package planbuilder import ( + "errors" "fmt" "strings" @@ -128,7 +129,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName switch v := ir.(type) { case ast.SelectStatement: return b.buildSelectStmt(inScope, v) - case ast.Values: + case *ast.AliasedValues: outScope = b.buildInsertValues(inScope, v, columnNames, tableName, destSchema) default: err := sql.ErrUnsupportedSyntax.New(ast.String(ir)) @@ -137,7 +138,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName return } -func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) { +func (b *Builder) buildInsertValues(inScope *scope, v *ast.AliasedValues, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) { columnDefaultValues := make([]*sql.ColumnDefaultValue, len(columnNames)) for i, columnName := range columnNames { @@ -156,8 +157,22 @@ func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames [] } } - exprTuples := make([][]sql.Expression, len(v)) - for i, vt := range v { + if !v.As.IsEmpty() { + if len(v.Columns) != 0 { + for _, tuple := range v.Values { + if len(v.Columns) != len(tuple) { + err := sql.ErrColumnCountMismatch.New() + b.handleErr(err) + } + } + + err := errors.New("insert row aliases are not currently supported; use the VALUES() function instead") + b.handleErr(err) + } + } + + exprTuples := make([][]sql.Expression, len(v.Values)) + for i, vt := range v.Values { // noExprs is an edge case where we fill VALUES with nil expressions noExprs := len(vt) == 0 // triggerUnknownTable is an edge case where we ignored an unresolved @@ -217,10 +232,10 @@ func reorderSchema(names []string, schema sql.Schema) sql.Schema { return newSch } -func (b *Builder) buildValues(inScope *scope, v ast.Values) (outScope *scope) { +func (b *Builder) buildValues(inScope *scope, v ast.AliasedValues) (outScope *scope) { // TODO add literals to outScope? - exprTuples := make([][]sql.Expression, len(v)) - for i, vt := range v { + exprTuples := make([][]sql.Expression, len(v.Values)) + for i, vt := range v.Values { exprs := make([]sql.Expression, len(vt)) exprTuples[i] = exprs for j, e := range vt {