diff --git a/dumpling/cmd/dumpling/main.go b/dumpling/cmd/dumpling/main.go index 52816fe0..fe93d773 100644 --- a/dumpling/cmd/dumpling/main.go +++ b/dumpling/cmd/dumpling/main.go @@ -27,7 +27,7 @@ import ( "github.com/docker/go-units" "github.com/pingcap/dumpling/v4/cli" "github.com/pingcap/dumpling/v4/export" - "github.com/pingcap/log" + "github.com/pingcap/dumpling/v4/log" filter "github.com/pingcap/tidb-tools/pkg/table-filter" "github.com/spf13/pflag" "go.uber.org/zap" diff --git a/dumpling/v4/export/connectionsPool.go b/dumpling/v4/export/connectionsPool.go new file mode 100644 index 00000000..0e8da805 --- /dev/null +++ b/dumpling/v4/export/connectionsPool.go @@ -0,0 +1,36 @@ +package export + +import ( + "context" + "database/sql" +) + +type connectionsPool struct { + conns chan *sql.Conn +} + +func newConnectionsPool(ctx context.Context, n int, pool *sql.DB) (*connectionsPool, error) { + connectPool := &connectionsPool{ + conns: make(chan *sql.Conn, n), + } + for i := 0; i < n; i++ { + conn, err := createConnWithConsistency(ctx, pool) + if err != nil { + return nil, err + } + connectPool.releaseConn(conn) + } + return connectPool, nil +} + +func (r *connectionsPool) getConn() *sql.Conn { + return <-r.conns +} + +func (r *connectionsPool) releaseConn(conn *sql.Conn) { + select { + case r.conns <- conn: + default: + panic("put a redundant conn") + } +} diff --git a/dumpling/v4/export/dump.go b/dumpling/v4/export/dump.go index 4d9901e8..e42e4a0f 100644 --- a/dumpling/v4/export/dump.go +++ b/dumpling/v4/export/dump.go @@ -66,10 +66,15 @@ func Dump(pCtx context.Context, conf *Config) (err error) { } if conf.Snapshot == "" && (doPdGC || conf.Consistency == "snapshot") { - conf.Snapshot, err = getSnapshot(pool) + conn, err := pool.Conn(ctx) + if err != nil { + return withStack(err) + } + conf.Snapshot, err = getSnapshot(conn) if err != nil { return err } + conn.Close() } if conf.Snapshot != "" { @@ -100,9 +105,10 @@ func Dump(pCtx context.Context, conf *Config) (err error) { "After dumping: run sql `update mysql.tidb set VARIABLE_VALUE = '10m' where VARIABLE_NAME = 'tikv_gc_life_time';` in tidb.\n") } - pool, err = resetDBWithSessionParams(pool, conf.getDSN(""), conf.SessionParams) - if err != nil { - return err + if newPool, err := resetDBWithSessionParams(pool, conf.getDSN(""), conf.SessionParams); err != nil { + return withStack(err) + } else { + pool = newPool } m := newGlobalMetadata(conf.OutputDirPath) @@ -112,14 +118,19 @@ func Dump(pCtx context.Context, conf *Config) (err error) { // for consistency lock, we should lock tables at first to get the tables we want to lock & dump // for consistency lock, record meta pos before lock tables because other tables may still be modified while locking tables if conf.Consistency == "lock" { + conn, err := createConnWithConsistency(ctx, pool) + if err != nil { + return err + } m.recordStartTime(time.Now()) - err = m.recordGlobalMetaData(pool, conf.ServerInfo.ServerType) + err = m.recordGlobalMetaData(conn, conf.ServerInfo.ServerType) if err != nil { log.Info("get global metadata failed", zap.Error(err)) } - if err = prepareTableListToDump(conf, pool); err != nil { + if err = prepareTableListToDump(conf, conn); err != nil { return err } + conn.Close() } conCtrl, err := NewConsistencyController(conf, pool) @@ -130,17 +141,28 @@ func Dump(pCtx context.Context, conf *Config) (err error) { return err } + connectPool, err := newConnectionsPool(ctx, conf.Threads, pool) + if err != nil { + return err + } + + if err = conCtrl.TearDown(); err != nil { + return err + } + // for other consistencies, we should get table list after consistency is set up and GlobalMetaData is cached // for other consistencies, record snapshot after whole tables are locked. The recorded meta info is exactly the locked snapshot. if conf.Consistency != "lock" { m.recordStartTime(time.Now()) - err = m.recordGlobalMetaData(pool, conf.ServerInfo.ServerType) + conn := connectPool.getConn() + err = m.recordGlobalMetaData(conn, conf.ServerInfo.ServerType) if err != nil { log.Info("get global metadata failed", zap.Error(err)) } - if err = prepareTableListToDump(conf, pool); err != nil { + if err = prepareTableListToDump(conf, conn); err != nil { return err } + connectPool.releaseConn(conn) } var writer Writer @@ -155,24 +177,26 @@ func Dump(pCtx context.Context, conf *Config) (err error) { } if conf.Sql == "" { - if err = dumpDatabases(ctx, conf, pool, writer); err != nil { + if err = dumpDatabases(ctx, conf, connectPool, writer); err != nil { return err } } else { - if err = dumpSql(ctx, conf, pool, writer); err != nil { + if err = dumpSql(ctx, conf, connectPool, writer); err != nil { return err } } m.recordFinishTime(time.Now()) - - return conCtrl.TearDown() + return nil } -func dumpDatabases(ctx context.Context, conf *Config, db *sql.DB, writer Writer) error { +func dumpDatabases(ctx context.Context, conf *Config, connectPool *connectionsPool, writer Writer) error { allTables := conf.Tables + var g errgroup.Group for dbName, tables := range allTables { - createDatabaseSQL, err := ShowCreateDatabase(db, dbName) + conn := connectPool.getConn() + createDatabaseSQL, err := ShowCreateDatabase(conn, dbName) + connectPool.releaseConn(conn) if err != nil { return err } @@ -183,24 +207,35 @@ func dumpDatabases(ctx context.Context, conf *Config, db *sql.DB, writer Writer) if len(tables) == 0 { continue } - rateLimit := newRateLimit(conf.Threads) - var g errgroup.Group for _, table := range tables { table := table - g.Go(func() error { - rateLimit.getToken() - defer rateLimit.putToken() - return dumpTable(ctx, conf, db, dbName, table, writer) - }) - } - if err := g.Wait(); err != nil { - return err + conn := connectPool.getConn() + tableDataIRArray, err := dumpTable(ctx, conf, conn, dbName, table, writer) + if err != nil { + return err + } + connectPool.releaseConn(conn) + for _, tableIR := range tableDataIRArray { + tableIR := tableIR + g.Go(func() error { + conn := connectPool.getConn() + defer connectPool.releaseConn(conn) + err := tableIR.Start(ctx, conn) + if err != nil { + return err + } + return writer.WriteTableData(ctx, tableIR) + }) + } } } + if err := g.Wait(); err != nil { + return err + } return nil } -func prepareTableListToDump(conf *Config, pool *sql.DB) error { +func prepareTableListToDump(conf *Config, pool *sql.Conn) error { databases, err := prepareDumpingDatabases(conf, pool) if err != nil { return err @@ -223,8 +258,10 @@ func prepareTableListToDump(conf *Config, pool *sql.DB) error { return nil } -func dumpSql(ctx context.Context, conf *Config, db *sql.DB, writer Writer) error { - tableIR, err := SelectFromSql(conf, db) +func dumpSql(ctx context.Context, conf *Config, connectPool *connectionsPool, writer Writer) error { + conn := connectPool.getConn() + tableIR, err := SelectFromSql(conf, conn) + connectPool.releaseConn(conn) if err != nil { return err } @@ -232,45 +269,45 @@ func dumpSql(ctx context.Context, conf *Config, db *sql.DB, writer Writer) error return writer.WriteTableData(ctx, tableIR) } -func dumpTable(ctx context.Context, conf *Config, db *sql.DB, dbName string, table *TableInfo, writer Writer) error { +func dumpTable(ctx context.Context, conf *Config, db *sql.Conn, dbName string, table *TableInfo, writer Writer) ([]TableDataIR, error) { tableName := table.Name if !conf.NoSchemas { if table.Type == TableTypeView { viewName := table.Name createViewSQL, err := ShowCreateView(db, dbName, viewName) if err != nil { - return err + return nil, err } - return writer.WriteTableMeta(ctx, dbName, viewName, createViewSQL) + return nil, writer.WriteTableMeta(ctx, dbName, viewName, createViewSQL) } createTableSQL, err := ShowCreateTable(db, dbName, tableName) if err != nil { - return err + return nil, err } if err := writer.WriteTableMeta(ctx, dbName, tableName, createTableSQL); err != nil { - return err + return nil, err } } // Do not dump table data and return nil if conf.NoData { - return nil + return nil, nil } if conf.Rows != UnspecifiedSize { - finished, err := concurrentDumpTable(ctx, writer, conf, db, dbName, tableName) + finished, chunksIterArray, err := concurrentDumpTable(ctx, conf, db, dbName, tableName) if err != nil || finished { - return err + return chunksIterArray, err } } tableIR, err := SelectAllFromTable(conf, db, dbName, tableName) if err != nil { - return err + return nil, err } - return writer.WriteTableData(ctx, tableIR) + return []TableDataIR{tableIR}, nil } -func concurrentDumpTable(ctx context.Context, writer Writer, conf *Config, db *sql.DB, dbName string, tableName string) (bool, error) { +func concurrentDumpTable(ctx context.Context, conf *Config, db *sql.Conn, dbName string, tableName string) (bool, []TableDataIR, error) { // try dump table concurrently by split table to chunks chunksIterCh := make(chan TableDataIR, defaultDumpThreads) errCh := make(chan error, defaultDumpThreads) @@ -279,6 +316,7 @@ func concurrentDumpTable(ctx context.Context, writer Writer, conf *Config, db *s ctx1, cancel1 := context.WithCancel(ctx) defer cancel1() var g errgroup.Group + chunksIterArray := make([]TableDataIR, 0) g.Go(func() error { splitTableDataIntoChunks(ctx1, chunksIterCh, errCh, linear, dbName, tableName, db, conf) return nil @@ -288,24 +326,22 @@ Loop: for { select { case <-ctx.Done(): - return true, nil + return true, chunksIterArray, nil case <-linear: - return false, nil + return false, chunksIterArray, nil case chunksIter, ok := <-chunksIterCh: if !ok { break Loop } - g.Go(func() error { - return writer.WriteTableData(ctx, chunksIter) - }) + chunksIterArray = append(chunksIterArray, chunksIter) case err := <-errCh: - return false, err + return false, chunksIterArray, err } } if err := g.Wait(); err != nil { - return true, err + return true, chunksIterArray, err } - return true, nil + return true, chunksIterArray, nil } func updateServiceSafePoint(ctx context.Context, pdClient pd.Client, ttl int64, snapshotTS uint64) { diff --git a/dumpling/v4/export/dump_test.go b/dumpling/v4/export/dump_test.go index afa7163a..18055eac 100644 --- a/dumpling/v4/export/dump_test.go +++ b/dumpling/v4/export/dump_test.go @@ -2,6 +2,7 @@ package export import ( "context" + "database/sql" "fmt" "strconv" @@ -27,6 +28,14 @@ func newMockWriter() *mockWriter { } } +func newMockConnectPool(c *C, db *sql.DB) *connectionsPool { + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) + connectPool := &connectionsPool{conns: make(chan *sql.Conn, 1)} + connectPool.releaseConn(conn) + return connectPool +} + func (m *mockWriter) WriteDatabaseMeta(ctx context.Context, db, createSQL string) error { m.databaseMeta[db] = createSQL return nil @@ -64,7 +73,8 @@ func (s *testDumpSuite) TestDumpDatabase(c *C) { mock.ExpectQuery("SELECT (.) FROM `test`.`t`").WillReturnRows(rows) mockWriter := newMockWriter() - err = dumpDatabases(context.Background(), mockConfig, db, mockWriter) + connectPool := newMockConnectPool(c, db) + err = dumpDatabases(context.Background(), mockConfig, connectPool, mockWriter) c.Assert(err, IsNil) c.Assert(len(mockWriter.databaseMeta), Equals, 1) @@ -78,6 +88,8 @@ func (s *testDumpSuite) TestDumpTable(c *C) { mockConfig.SortByPk = false db, mock, err := sqlmock.New() c.Assert(err, IsNil) + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) showCreateTableResult := "CREATE TABLE t (a INT)" rows := mock.NewRows([]string{"Table", "Create Table"}).AddRow("t", showCreateTableResult) @@ -90,8 +102,13 @@ func (s *testDumpSuite) TestDumpTable(c *C) { mock.ExpectQuery("SELECT (.) FROM `test`.`t`").WillReturnRows(rows) mockWriter := newMockWriter() - err = dumpTable(context.Background(), mockConfig, db, "test", &TableInfo{Name: "t"}, mockWriter) + ctx := context.Background() + tableIRArray, err := dumpTable(ctx, mockConfig, conn, "test", &TableInfo{Name: "t"}, mockWriter) c.Assert(err, IsNil) + for _, tableIR := range tableIRArray { + c.Assert(tableIR.Start(ctx, conn), IsNil) + c.Assert(mockWriter.WriteTableData(ctx, tableIR), IsNil) + } c.Assert(mockWriter.tableMeta["test.t"], Equals, showCreateTableResult) c.Assert(len(mockWriter.tableData), Equals, 1) @@ -121,6 +138,8 @@ func (s *testDumpSuite) TestDumpTableWhereClause(c *C) { mockConfig.Where = "a > 3 and a < 9" db, mock, err := sqlmock.New() c.Assert(err, IsNil) + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) showCreateTableResult := "CREATE TABLE t (a INT)" rows := mock.NewRows([]string{"Table", "Create Table"}).AddRow("t", showCreateTableResult) @@ -137,8 +156,13 @@ func (s *testDumpSuite) TestDumpTableWhereClause(c *C) { mock.ExpectQuery("SELECT (.) FROM `test`.`t` WHERE a > 3 and a < 9").WillReturnRows(rows) mockWriter := newMockWriter() - err = dumpTable(context.Background(), mockConfig, db, "test", &TableInfo{Name: "t"}, mockWriter) + ctx := context.Background() + tableIRArray, err := dumpTable(ctx, mockConfig, conn, "test", &TableInfo{Name: "t"}, mockWriter) c.Assert(err, IsNil) + for _, tableIR := range tableIRArray { + c.Assert(tableIR.Start(ctx, conn), IsNil) + c.Assert(mockWriter.WriteTableData(ctx, tableIR), IsNil) + } c.Assert(mockWriter.tableMeta["test.t"], Equals, showCreateTableResult) c.Assert(len(mockWriter.tableData), Equals, 1) diff --git a/dumpling/v4/export/ir.go b/dumpling/v4/export/ir.go index 71a71a3f..d7dc33e7 100644 --- a/dumpling/v4/export/ir.go +++ b/dumpling/v4/export/ir.go @@ -2,11 +2,13 @@ package export import ( "bytes" + "context" "database/sql" ) // TableDataIR is table data intermediate representation. type TableDataIR interface { + Start(context.Context, *sql.Conn) error DatabaseName() string TableName() string ChunkIndex() int diff --git a/dumpling/v4/export/ir_impl.go b/dumpling/v4/export/ir_impl.go index 321ca549..5da9a54a 100644 --- a/dumpling/v4/export/ir_impl.go +++ b/dumpling/v4/export/ir_impl.go @@ -146,6 +146,7 @@ func (m *stringIter) HasNext() bool { type tableData struct { database string table string + query string chunkIndex int rows *sql.Rows colTypes []*sql.ColumnType @@ -154,6 +155,15 @@ type tableData struct { escapeBackslash bool } +func (td *tableData) Start(ctx context.Context, conn *sql.Conn) error { + rows, err := conn.QueryContext(ctx, td.query) + if err != nil { + return err + } + td.rows = rows + return nil +} + func (td *tableData) ColumnTypes() []string { colTypes := make([]string, len(td.colTypes)) for i, ct := range td.colTypes { @@ -241,7 +251,7 @@ func splitTableDataIntoChunks( tableDataIRCh chan TableDataIR, errCh chan error, linear chan struct{}, - dbName, tableName string, db *sql.DB, conf *Config) { + dbName, tableName string, db *sql.Conn, conf *Config) { field, err := pickupPossibleField(dbName, tableName, db, conf) if err != nil { errCh <- withStack(err) @@ -263,7 +273,7 @@ func splitTableDataIntoChunks( var smin sql.NullString var smax sql.NullString - row := db.QueryRow(query) + row := db.QueryRowContext(ctx, query) err = row.Scan(&smin, &smax) if err != nil { log.Error("split chunks - get max min failed", zap.String("query", query), zap.Error(err)) @@ -329,11 +339,6 @@ LOOP: chunkIndex += 1 where := fmt.Sprintf("%s(`%s` >= %d AND `%s` < %d)", nullValueCondition, escapeString(field), cutoff, escapeString(field), cutoff+estimatedStep) query = buildSelectQuery(dbName, tableName, selectedField, buildWhereCondition(conf, where), orderByClause) - rows, err := db.Query(query) - if err != nil { - errCh <- errors.WithMessage(err, query) - return - } if len(nullValueCondition) > 0 { nullValueCondition = "" } @@ -341,7 +346,7 @@ LOOP: td := &tableData{ database: dbName, table: tableName, - rows: rows, + query: query, chunkIndex: chunkIndex, colTypes: colTypes, selectedField: selectedField, diff --git a/dumpling/v4/export/metadata.go b/dumpling/v4/export/metadata.go index 36e14af4..63d5d8c8 100644 --- a/dumpling/v4/export/metadata.go +++ b/dumpling/v4/export/metadata.go @@ -2,6 +2,7 @@ package export import ( "bytes" + "context" "database/sql" "errors" "fmt" @@ -46,7 +47,7 @@ func (m *globalMetadata) recordFinishTime(t time.Time) { m.buffer.WriteString("Finished dump at: " + t.Format(metadataTimeLayout) + "\n") } -func (m *globalMetadata) recordGlobalMetaData(db *sql.DB, serverType ServerType) error { +func (m *globalMetadata) recordGlobalMetaData(db *sql.Conn, serverType ServerType) error { // get master status info m.buffer.WriteString("SHOW MASTER STATUS:\n") switch serverType { @@ -107,7 +108,7 @@ func (m *globalMetadata) recordGlobalMetaData(db *sql.DB, serverType ServerType) m.buffer.WriteString("\tPos: " + pos + "\n") } var gtidSet string - err = db.QueryRow("SELECT @@global.gtid_binlog_pos").Scan(>idSet) + err = db.QueryRowContext(context.Background(), "SELECT @@global.gtid_binlog_pos").Scan(>idSet) if err != nil { return err } diff --git a/dumpling/v4/export/metadata_test.go b/dumpling/v4/export/metadata_test.go index e412ed96..0918f38d 100644 --- a/dumpling/v4/export/metadata_test.go +++ b/dumpling/v4/export/metadata_test.go @@ -1,6 +1,7 @@ package export import ( + "context" "fmt" "path" @@ -16,6 +17,8 @@ func (s *testMetaDataSuite) TestMysqlMetaData(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) logFile := "ON.000001" pos := "7502" @@ -29,7 +32,7 @@ func (s *testMetaDataSuite) TestMysqlMetaData(c *C) { testFilePath := "/test" m := newGlobalMetadata(testFilePath) - c.Assert(m.recordGlobalMetaData(db, ServerTypeMySQL), IsNil) + c.Assert(m.recordGlobalMetaData(conn, ServerTypeMySQL), IsNil) c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath)) c.Assert(m.buffer.String(), Equals, "SHOW MASTER STATUS:\n"+ @@ -43,6 +46,8 @@ func (s *testMetaDataSuite) TestMysqlWithFollowersMetaData(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) logFile := "ON.000001" pos := "7502" @@ -57,7 +62,7 @@ func (s *testMetaDataSuite) TestMysqlWithFollowersMetaData(c *C) { testFilePath := "/test" m := newGlobalMetadata(testFilePath) - c.Assert(m.recordGlobalMetaData(db, ServerTypeMySQL), IsNil) + c.Assert(m.recordGlobalMetaData(conn, ServerTypeMySQL), IsNil) c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath)) c.Assert(m.buffer.String(), Equals, "SHOW MASTER STATUS:\n"+ @@ -76,6 +81,8 @@ func (s *testMetaDataSuite) TestMariaDBMetaData(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) logFile := "mariadb-bin.000016" pos := "475" @@ -89,7 +96,7 @@ func (s *testMetaDataSuite) TestMariaDBMetaData(c *C) { mock.ExpectQuery("SHOW SLAVE STATUS").WillReturnRows(rows) testFilePath := "/test" m := newGlobalMetadata(testFilePath) - c.Assert(m.recordGlobalMetaData(db, ServerTypeMariaDB), IsNil) + c.Assert(m.recordGlobalMetaData(conn, ServerTypeMariaDB), IsNil) c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath)) c.Assert(mock.ExpectationsWereMet(), IsNil) @@ -99,6 +106,8 @@ func (s *testMetaDataSuite) TestMariaDBWithFollowersMetaData(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) logFile := "ON.000001" pos := "7502" @@ -116,7 +125,7 @@ func (s *testMetaDataSuite) TestMariaDBWithFollowersMetaData(c *C) { testFilePath := "/test" m := newGlobalMetadata(testFilePath) - c.Assert(m.recordGlobalMetaData(db, ServerTypeMySQL), IsNil) + c.Assert(m.recordGlobalMetaData(conn, ServerTypeMySQL), IsNil) c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath)) c.Assert(m.buffer.String(), Equals, "SHOW MASTER STATUS:\n"+ diff --git a/dumpling/v4/export/prepare.go b/dumpling/v4/export/prepare.go index 3d123c73..91ea562b 100644 --- a/dumpling/v4/export/prepare.go +++ b/dumpling/v4/export/prepare.go @@ -64,7 +64,7 @@ func detectServerInfo(db *sql.DB) (ServerInfo, error) { return ParseServerInfo(versionStr), nil } -func prepareDumpingDatabases(conf *Config, db *sql.DB) ([]string, error) { +func prepareDumpingDatabases(conf *Config, db *sql.Conn) ([]string, error) { databases, err := ShowDatabases(db) if len(conf.Databases) == 0 { return databases, err @@ -86,12 +86,12 @@ func prepareDumpingDatabases(conf *Config, db *sql.DB) ([]string, error) { } } -func listAllTables(db *sql.DB, databaseNames []string) (DatabaseTables, error) { +func listAllTables(db *sql.Conn, databaseNames []string) (DatabaseTables, error) { log.Debug("list all the tables") return ListAllDatabasesTables(db, databaseNames, TableTypeBase) } -func listAllViews(db *sql.DB, databaseNames []string) (DatabaseTables, error) { +func listAllViews(db *sql.Conn, databaseNames []string) (DatabaseTables, error) { log.Debug("list all the views") return ListAllDatabasesTables(db, databaseNames, TableTypeView) } diff --git a/dumpling/v4/export/prepare_test.go b/dumpling/v4/export/prepare_test.go index 7233e925..3fc038b5 100644 --- a/dumpling/v4/export/prepare_test.go +++ b/dumpling/v4/export/prepare_test.go @@ -1,6 +1,7 @@ package export import ( + "context" "fmt" "github.com/DATA-DOG/go-sqlmock" @@ -15,6 +16,8 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) rows := sqlmock.NewRows([]string{"Database"}). AddRow("db1"). @@ -24,7 +27,7 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) { mock.ExpectQuery("SHOW DATABASES").WillReturnRows(rows) conf := DefaultConfig() conf.Databases = []string{"db1", "db2", "db3"} - result, err := prepareDumpingDatabases(conf, db) + result, err := prepareDumpingDatabases(conf, conn) c.Assert(err, IsNil) c.Assert(result, DeepEquals, []string{"db1", "db2", "db3"}) @@ -33,12 +36,12 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) { AddRow("db1"). AddRow("db2") mock.ExpectQuery("SHOW DATABASES").WillReturnRows(rows) - result, err = prepareDumpingDatabases(conf, db) + result, err = prepareDumpingDatabases(conf, conn) c.Assert(err, IsNil) c.Assert(result, DeepEquals, []string{"db1", "db2"}) mock.ExpectQuery("SHOW DATABASES").WillReturnError(fmt.Errorf("err")) - _, err = prepareDumpingDatabases(conf, db) + _, err = prepareDumpingDatabases(conf, conn) c.Assert(err, NotNil) rows = sqlmock.NewRows([]string{"Database"}). @@ -48,7 +51,7 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) { AddRow("db5") mock.ExpectQuery("SHOW DATABASES").WillReturnRows(rows) conf.Databases = []string{"db1", "db2", "db4", "db6"} - _, err = prepareDumpingDatabases(conf, db) + _, err = prepareDumpingDatabases(conf, conn) c.Assert(err, ErrorMatches, `Unknown databases \[db4,db6\]`) c.Assert(mock.ExpectationsWereMet(), IsNil) } @@ -57,6 +60,8 @@ func (s *testPrepareSuite) TestListAllTables(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) data := NewDatabaseTables(). AppendTables("db1", "t1", "t2"). @@ -78,7 +83,7 @@ func (s *testPrepareSuite) TestListAllTables(c *C) { query := "SELECT table_schema,table_name FROM information_schema.tables WHERE table_type = (.*)" mock.ExpectQuery(query).WillReturnRows(rows) - tables, err := listAllTables(db, dbNames) + tables, err := listAllTables(conn, dbNames) c.Assert(err, IsNil) for d, t := range tables { @@ -96,7 +101,7 @@ func (s *testPrepareSuite) TestListAllTables(c *C) { AppendViews("db", "t2") query = "SELECT table_schema,table_name FROM information_schema.tables WHERE table_type = (.*)" mock.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"table_schema", "table_name"}).AddRow("db", "t2")) - tables, err = listAllViews(db, []string{"db"}) + tables, err = listAllViews(conn, []string{"db"}) c.Assert(err, IsNil) c.Assert(len(tables), Equals, 1) c.Assert(len(tables["db"]), Equals, 1) diff --git a/dumpling/v4/export/ratelimit.go b/dumpling/v4/export/ratelimit.go deleted file mode 100644 index aafb5695..00000000 --- a/dumpling/v4/export/ratelimit.go +++ /dev/null @@ -1,23 +0,0 @@ -package export - -type rateLimit struct { - token chan struct{} -} - -func newRateLimit(n int) *rateLimit { - return &rateLimit{ - token: make(chan struct{}, n), - } -} - -func (r *rateLimit) getToken() { - r.token <- struct{}{} -} - -func (r *rateLimit) putToken() { - select { - case <-r.token: - default: - panic("put a redundant token") - } -} diff --git a/dumpling/v4/export/sql.go b/dumpling/v4/export/sql.go index 4a9396ea..a5bfbfd0 100644 --- a/dumpling/v4/export/sql.go +++ b/dumpling/v4/export/sql.go @@ -1,6 +1,7 @@ package export import ( + "context" "database/sql" "fmt" "net/url" @@ -13,7 +14,7 @@ import ( "github.com/pingcap/dumpling/v4/log" ) -func ShowDatabases(db *sql.DB) ([]string, error) { +func ShowDatabases(db *sql.Conn) ([]string, error) { var res oneStrColumnTable if err := simpleQuery(db, "SHOW DATABASES", res.handleOneRow); err != nil { return nil, err @@ -22,7 +23,7 @@ func ShowDatabases(db *sql.DB) ([]string, error) { } // ShowTables shows the tables of a database, the caller should use the correct database. -func ShowTables(db *sql.DB) ([]string, error) { +func ShowTables(db *sql.Conn) ([]string, error) { var res oneStrColumnTable if err := simpleQuery(db, "SHOW TABLES", res.handleOneRow); err != nil { return nil, err @@ -30,7 +31,7 @@ func ShowTables(db *sql.DB) ([]string, error) { return res.data, nil } -func ShowCreateDatabase(db *sql.DB, database string) (string, error) { +func ShowCreateDatabase(db *sql.Conn, database string) (string, error) { var oneRow [2]string handleOneRow := func(rows *sql.Rows) error { return rows.Scan(&oneRow[0], &oneRow[1]) @@ -43,7 +44,7 @@ func ShowCreateDatabase(db *sql.DB, database string) (string, error) { return oneRow[1], nil } -func ShowCreateTable(db *sql.DB, database, table string) (string, error) { +func ShowCreateTable(db *sql.Conn, database, table string) (string, error) { var oneRow [2]string handleOneRow := func(rows *sql.Rows) error { return rows.Scan(&oneRow[0], &oneRow[1]) @@ -56,7 +57,7 @@ func ShowCreateTable(db *sql.DB, database, table string) (string, error) { return oneRow[1], nil } -func ShowCreateView(db *sql.DB, database, view string) (string, error) { +func ShowCreateView(db *sql.Conn, database, view string) (string, error) { var oneRow [4]string handleOneRow := func(rows *sql.Rows) error { return rows.Scan(&oneRow[0], &oneRow[1], &oneRow[2], &oneRow[3]) @@ -69,16 +70,7 @@ func ShowCreateView(db *sql.DB, database, view string) (string, error) { return oneRow[1], nil } -func ListAllTables(db *sql.DB, database string) ([]string, error) { - var tables oneStrColumnTable - const query = "SELECT table_name FROM information_schema.tables WHERE table_schema = ? and table_type = 'BASE TABLE'" - if err := simpleQueryWithArgs(db, tables.handleOneRow, query, database); err != nil { - return nil, errors.WithMessage(err, query) - } - return tables.data, nil -} - -func ListAllDatabasesTables(db *sql.DB, databaseNames []string, tableType TableType) (DatabaseTables, error) { +func ListAllDatabasesTables(db *sql.Conn, databaseNames []string, tableType TableType) (DatabaseTables, error) { var tableTypeStr string switch tableType { case TableTypeBase: @@ -112,7 +104,16 @@ func ListAllDatabasesTables(db *sql.DB, databaseNames []string, tableType TableT return dbTables, nil } -func ListAllViews(db *sql.DB, database string) ([]string, error) { +func ListAllTables(db *sql.Conn, database string) ([]string, error) { + var tables oneStrColumnTable + const query = "SELECT table_name FROM information_schema.tables WHERE table_schema = ? and table_type = 'BASE TABLE'" + if err := simpleQueryWithArgs(db, tables.handleOneRow, query, database); err != nil { + return nil, errors.WithMessage(err, query) + } + return tables.data, nil +} + +func ListAllViews(db *sql.Conn, database string) ([]string, error) { var views oneStrColumnTable const query = "SELECT table_name FROM information_schema.tables WHERE table_schema = ? and table_type = 'VIEW'" if err := simpleQueryWithArgs(db, views.handleOneRow, query, database); err != nil { @@ -123,17 +124,15 @@ func ListAllViews(db *sql.DB, database string) ([]string, error) { func SelectVersion(db *sql.DB) (string, error) { var versionInfo string - handleOneRow := func(rows *sql.Rows) error { - return rows.Scan(&versionInfo) - } - err := simpleQuery(db, "SELECT version()", handleOneRow) + row := db.QueryRow("SELECT version()") + err := row.Scan(&versionInfo) if err != nil { return "", withStack(err) } return versionInfo, nil } -func SelectAllFromTable(conf *Config, db *sql.DB, database, table string) (TableDataIR, error) { +func SelectAllFromTable(conf *Config, db *sql.Conn, database, table string) (TableDataIR, error) { selectedField, err := buildSelectField(db, database, table) if err != nil { return nil, err @@ -150,15 +149,11 @@ func SelectAllFromTable(conf *Config, db *sql.DB, database, table string) (Table } query := buildSelectQuery(database, table, selectedField, buildWhereCondition(conf, ""), orderByClause) - rows, err := db.Query(query) - if err != nil { - return nil, withStack(errors.WithMessage(err, query)) - } return &tableData{ database: database, table: table, - rows: rows, + query: query, colTypes: colTypes, selectedField: selectedField, escapeBackslash: conf.EscapeBackslash, @@ -168,8 +163,8 @@ func SelectAllFromTable(conf *Config, db *sql.DB, database, table string) (Table }, nil } -func SelectFromSql(conf *Config, db *sql.DB) (TableDataIR, error) { - rows, err := db.Query(conf.Sql) +func SelectFromSql(conf *Config, db *sql.Conn) (TableDataIR, error) { + rows, err := db.QueryContext(context.Background(), conf.Sql) if err != nil { return nil, withStack(errors.WithMessage(err, conf.Sql)) } @@ -220,7 +215,7 @@ func buildSelectQuery(database, table string, fields string, where string, order return query.String() } -func buildOrderByClause(conf *Config, db *sql.DB, database, table string) (string, error) { +func buildOrderByClause(conf *Config, db *sql.Conn, database, table string) (string, error) { if !conf.SortByPk { return "", nil } @@ -246,10 +241,10 @@ func buildOrderByClause(conf *Config, db *sql.DB, database, table string) (strin return "", nil } -func SelectTiDBRowID(db *sql.DB, database, table string) (bool, error) { +func SelectTiDBRowID(db *sql.Conn, database, table string) (bool, error) { const errBadFieldCode = 1054 tiDBRowIDQuery := fmt.Sprintf("SELECT _tidb_rowid from `%s`.`%s` LIMIT 0", escapeString(database), escapeString(table)) - _, err := db.Exec(tiDBRowIDQuery) + _, err := db.ExecContext(context.Background(), tiDBRowIDQuery) if err != nil { errMsg := strings.ToLower(err.Error()) if strings.Contains(errMsg, fmt.Sprintf("%d", errBadFieldCode)) { @@ -260,9 +255,9 @@ func SelectTiDBRowID(db *sql.DB, database, table string) (bool, error) { return true, nil } -func GetColumnTypes(db *sql.DB, fields, database, table string) ([]*sql.ColumnType, error) { +func GetColumnTypes(db *sql.Conn, fields, database, table string) ([]*sql.ColumnType, error) { query := fmt.Sprintf("SELECT %s FROM `%s`.`%s` LIMIT 1", fields, escapeString(database), escapeString(table)) - rows, err := db.Query(query) + rows, err := db.QueryContext(context.Background(), query) if err != nil { return nil, withStack(errors.WithMessage(err, query)) } @@ -270,11 +265,11 @@ func GetColumnTypes(db *sql.DB, fields, database, table string) ([]*sql.ColumnTy return rows.ColumnTypes() } -func GetPrimaryKeyName(db *sql.DB, database, table string) (string, error) { +func GetPrimaryKeyName(db *sql.Conn, database, table string) (string, error) { priKeyQuery := "SELECT column_name FROM information_schema.columns " + "WHERE table_schema = ? AND table_name = ? AND column_key = 'PRI';" var colName string - row := db.QueryRow(priKeyQuery, database, table) + row := db.QueryRowContext(context.Background(), priKeyQuery, database, table) if err := row.Scan(&colName); err != nil { if err == sql.ErrNoRows { return "", nil @@ -285,11 +280,11 @@ func GetPrimaryKeyName(db *sql.DB, database, table string) (string, error) { return colName, nil } -func GetUniqueIndexName(db *sql.DB, database, table string) (string, error) { +func GetUniqueIndexName(db *sql.Conn, database, table string) (string, error) { uniKeyQuery := "SELECT column_name FROM information_schema.columns " + "WHERE table_schema = ? AND table_name = ? AND column_key = 'UNI';" var colName string - row := db.QueryRow(uniKeyQuery, database, table) + row := db.QueryRowContext(context.Background(), uniKeyQuery, database, table) if err := row.Scan(&colName); err != nil { if err == sql.ErrNoRows { return "", nil @@ -320,7 +315,7 @@ func UseDatabase(db *sql.DB, databaseName string) error { return withStack(err) } -func ShowMasterStatus(db *sql.DB, fieldNum int) ([]string, error) { +func ShowMasterStatus(db *sql.Conn, fieldNum int) ([]string, error) { oneRow := make([]string, fieldNum) addr := make([]interface{}, fieldNum) for i := range oneRow { @@ -387,17 +382,15 @@ func GetTiDBDDLIDs(db *sql.DB) ([]string, error) { func CheckTiDBWithTiKV(db *sql.DB) (bool, error) { var count int - handleOneRow := func(rows *sql.Rows) error { - return rows.Scan(&count) - } - err := simpleQuery(db, "SELECT COUNT(1) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'", handleOneRow) + row := db.QueryRow("SELECT COUNT(1) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'") + err := row.Scan(&count) if err != nil { return false, err } return count > 0, nil } -func getSnapshot(db *sql.DB) (string, error) { +func getSnapshot(db *sql.Conn) (string, error) { str, err := ShowMasterStatus(db, showMasterStatusFieldNum) if err != nil { return "", err @@ -445,9 +438,25 @@ func resetDBWithSessionParams(db *sql.DB, dsn string, params map[string]interfac return newDB, nil } -func buildSelectField(db *sql.DB, dbName, tableName string) (string, error) { +func createConnWithConsistency(ctx context.Context, db *sql.DB) (*sql.Conn, error) { + conn, err := db.Conn(ctx) + if err != nil { + return nil, err + } + _, err = conn.ExecContext(ctx, "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ") + if err != nil { + return nil, err + } + _, err = conn.ExecContext(ctx, "START TRANSACTION /*!40108 WITH CONSISTENT SNAPSHOT */") + if err != nil { + return nil, err + } + return conn, err +} + +func buildSelectField(db *sql.Conn, dbName, tableName string) (string, error) { query := `SELECT COLUMN_NAME,EXTRA FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA=? AND TABLE_NAME=?;` - rows, err := db.Query(query, dbName, tableName) + rows, err := db.QueryContext(context.Background(), query, dbName, tableName) if err != nil { return "", err } @@ -488,12 +497,12 @@ func (o *oneStrColumnTable) handleOneRow(rows *sql.Rows) error { return nil } -func simpleQuery(db *sql.DB, sql string, handleOneRow func(*sql.Rows) error) error { - return simpleQueryWithArgs(db, handleOneRow, sql) +func simpleQuery(conn *sql.Conn, sql string, handleOneRow func(*sql.Rows) error) error { + return simpleQueryWithArgs(conn, handleOneRow, sql) } -func simpleQueryWithArgs(db *sql.DB, handleOneRow func(*sql.Rows) error, sql string, args ...interface{}) error { - rows, err := db.Query(sql, args...) +func simpleQueryWithArgs(conn *sql.Conn, handleOneRow func(*sql.Rows) error, sql string, args ...interface{}) error { + rows, err := conn.QueryContext(context.Background(), sql, args...) if err != nil { return withStack(err) } @@ -507,7 +516,7 @@ func simpleQueryWithArgs(db *sql.DB, handleOneRow func(*sql.Rows) error, sql str return rows.Err() } -func pickupPossibleField(dbName, tableName string, db *sql.DB, conf *Config) (string, error) { +func pickupPossibleField(dbName, tableName string, db *sql.Conn, conf *Config) (string, error) { // If detected server is TiDB, try using _tidb_rowid if conf.ServerInfo.ServerType == ServerTypeTiDB { ok, err := SelectTiDBRowID(db, dbName, tableName) @@ -539,7 +548,7 @@ func pickupPossibleField(dbName, tableName string, db *sql.DB, conf *Config) (st query := "SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS " + "WHERE TABLE_NAME = ? AND COLUMN_NAME = ?" var fieldType string - row := db.QueryRow(query, tableName, fieldName) + row := db.QueryRowContext(context.Background(), query, tableName, fieldName) err = row.Scan(&fieldType) if err != nil { if err == sql.ErrNoRows { @@ -555,7 +564,7 @@ func pickupPossibleField(dbName, tableName string, db *sql.DB, conf *Config) (st return "", nil } -func estimateCount(dbName, tableName string, db *sql.DB, field string, conf *Config) uint64 { +func estimateCount(dbName, tableName string, db *sql.Conn, field string, conf *Config) uint64 { query := fmt.Sprintf("EXPLAIN SELECT `%s` FROM `%s`.`%s`", field, escapeString(dbName), escapeString(tableName)) if conf.Where != "" { @@ -592,13 +601,14 @@ func estimateCount(dbName, tableName string, db *sql.DB, field string, conf *Con return 0 } -func detectEstimateRows(db *sql.DB, query string, fieldNames []string) uint64 { - row, err := db.Query(query) +func detectEstimateRows(db *sql.Conn, query string, fieldNames []string) uint64 { + row, err := db.QueryContext(context.Background(), query) if err != nil { log.Warn("can't execute query from db", zap.String("query", query), zap.Error(err)) return 0 } + defer row.Close() row.Next() columns, _ := row.Columns() addr := make([]interface{}, len(columns)) @@ -638,19 +648,14 @@ func parseSnapshotToTSO(pool *sql.DB, snapshot string) (uint64, error) { return snapshotTS, nil } var tso sql.NullInt64 - err = simpleQueryWithArgs(pool, func(rows *sql.Rows) error { - err := rows.Scan(&tso) - if err != nil { - return err - } - if !tso.Valid { - return fmt.Errorf("snapshot %s format not supported. please use tso or '2006-01-02 15:04:05' format time", snapshot) - } - return nil - }, "SELECT unix_timestamp(?)", snapshot) + row := pool.QueryRow("SELECT unix_timestamp(?)", snapshot) + err = row.Scan(&tso) if err != nil { return 0, withStack(err) } + if !tso.Valid { + return 0, withStack(fmt.Errorf("snapshot %s format not supported. please use tso or '2006-01-02 15:04:05' format time", snapshot)) + } return (uint64(tso.Int64)<<18)*1000 + 1, nil } diff --git a/dumpling/v4/export/sql_test.go b/dumpling/v4/export/sql_test.go index 79bc2b10..ee02e534 100644 --- a/dumpling/v4/export/sql_test.go +++ b/dumpling/v4/export/sql_test.go @@ -1,6 +1,7 @@ package export import ( + "context" "errors" "github.com/DATA-DOG/go-sqlmock" @@ -54,6 +55,8 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) mockConf := DefaultConfig() mockConf.SortByPk = true @@ -65,14 +68,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { mock.ExpectExec("SELECT _tidb_rowid from `test`.`t`"). WillReturnResult(sqlmock.NewResult(0, 0)) - orderByClause, err := buildOrderByClause(mockConf, db, "test", "t") + orderByClause, err := buildOrderByClause(mockConf, conn, "test", "t") c.Assert(err, IsNil) mock.ExpectQuery("SELECT COLUMN_NAME"). WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", "")) - selectedField, err := buildSelectField(db, "test", "t") + selectedField, err := buildSelectField(conn, "test", "t") c.Assert(err, IsNil) q := buildSelectQuery("test", "t", selectedField, "", orderByClause) c.Assert(q, Equals, "SELECT * FROM `test`.`t` ORDER BY _tidb_rowid") @@ -81,14 +84,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { mock.ExpectExec("SELECT _tidb_rowid from `test`.`t`"). WillReturnError(errors.New(`1054, "Unknown column '_tidb_rowid' in 'field list'"`)) - orderByClause, err = buildOrderByClause(mockConf, db, "test", "t") + orderByClause, err = buildOrderByClause(mockConf, conn, "test", "t") c.Assert(err, IsNil) mock.ExpectQuery("SELECT COLUMN_NAME"). WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", "")) - selectedField, err = buildSelectField(db, "test", "t") + selectedField, err = buildSelectField(conn, "test", "t") c.Assert(err, IsNil) q = buildSelectQuery("test", "t", selectedField, "", orderByClause) c.Assert(q, Equals, "SELECT * FROM `test`.`t`") @@ -104,14 +107,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { mock.ExpectQuery("SELECT column_name FROM information_schema.columns"). WithArgs("test", "t"). WillReturnRows(sqlmock.NewRows([]string{"column_name"}).AddRow("id")) - orderByClause, err := buildOrderByClause(mockConf, db, "test", "t") + orderByClause, err := buildOrderByClause(mockConf, conn, "test", "t") c.Assert(err, IsNil, cmt) mock.ExpectQuery("SELECT COLUMN_NAME"). WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", "")) - selectedField, err = buildSelectField(db, "test", "t") + selectedField, err = buildSelectField(conn, "test", "t") c.Assert(err, IsNil) q = buildSelectQuery("test", "t", selectedField, "", orderByClause) c.Assert(q, Equals, "SELECT * FROM `test`.`t` ORDER BY `id`", cmt) @@ -128,14 +131,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { WithArgs("test", "t"). WillReturnRows(sqlmock.NewRows([]string{"column_name"})) - orderByClause, err := buildOrderByClause(mockConf, db, "test", "t") + orderByClause, err := buildOrderByClause(mockConf, conn, "test", "t") c.Assert(err, IsNil, cmt) mock.ExpectQuery("SELECT COLUMN_NAME"). WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", "")) - selectedField, err = buildSelectField(db, "test", "t") + selectedField, err = buildSelectField(conn, "test", "t") c.Assert(err, IsNil) q := buildSelectQuery("test", "t", selectedField, "", orderByClause) c.Assert(q, Equals, "SELECT * FROM `test`.`t`", cmt) @@ -154,7 +157,7 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) { WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", "")) - selectedField, err := buildSelectField(db, "test", "t") + selectedField, err := buildSelectField(conn, "test", "t") c.Assert(err, IsNil) q := buildSelectQuery("test", "t", selectedField, "", "") c.Assert(q, Equals, "SELECT * FROM `test`.`t`", cmt) @@ -166,13 +169,15 @@ func (s *testDumpSuite) TestBuildSelectField(c *C) { db, mock, err := sqlmock.New() c.Assert(err, IsNil) defer db.Close() + conn, err := db.Conn(context.Background()) + c.Assert(err, IsNil) // generate columns not found mock.ExpectQuery("SELECT COLUMN_NAME"). WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", "")) - selectedField, err := buildSelectField(db, "test", "t") + selectedField, err := buildSelectField(conn, "test", "t") c.Assert(selectedField, Equals, "*") c.Assert(err, IsNil) c.Assert(mock.ExpectationsWereMet(), IsNil) @@ -183,7 +188,7 @@ func (s *testDumpSuite) TestBuildSelectField(c *C) { WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}). AddRow("id", "").AddRow("name", "").AddRow("quo`te", "").AddRow("generated", "VIRTUAL GENERATED")) - selectedField, err = buildSelectField(db, "test", "t") + selectedField, err = buildSelectField(conn, "test", "t") c.Assert(selectedField, Equals, "`id`,`name`,`quo``te`") c.Assert(err, IsNil) c.Assert(mock.ExpectationsWereMet(), IsNil) diff --git a/dumpling/v4/export/test_util.go b/dumpling/v4/export/test_util.go index 1c625541..f4fee268 100644 --- a/dumpling/v4/export/test_util.go +++ b/dumpling/v4/export/test_util.go @@ -1,6 +1,7 @@ package export import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -87,6 +88,10 @@ type mockTableIR struct { rowErr error } +func (m *mockTableIR) Start(ctx context.Context, conn *sql.Conn) error { + return nil +} + func (m *mockTableIR) DatabaseName() string { return m.dbName }