From bac3665748b2980a06890836cb298c20a2f8e6ab Mon Sep 17 00:00:00 2001
From: Chunzhu Li <lichunzhu@stu.xjtu.edu.cn>
Date: Mon, 3 Aug 2020 17:35:44 +0800
Subject: [PATCH] Optimize mysql consistency (#121)

* refine conn pool
---
 dumpling/cmd/dumpling/main.go         |   2 +-
 dumpling/v4/export/connectionsPool.go |  36 +++++++
 dumpling/v4/export/dump.go            | 128 ++++++++++++++++---------
 dumpling/v4/export/dump_test.go       |  30 +++++-
 dumpling/v4/export/ir.go              |   2 +
 dumpling/v4/export/ir_impl.go         |  21 ++--
 dumpling/v4/export/metadata.go        |   5 +-
 dumpling/v4/export/metadata_test.go   |  17 +++-
 dumpling/v4/export/prepare.go         |   6 +-
 dumpling/v4/export/prepare_test.go    |  17 ++--
 dumpling/v4/export/ratelimit.go       |  23 -----
 dumpling/v4/export/sql.go             | 133 +++++++++++++-------------
 dumpling/v4/export/sql_test.go        |  27 +++---
 dumpling/v4/export/test_util.go       |   5 +
 14 files changed, 281 insertions(+), 171 deletions(-)
 create mode 100644 dumpling/v4/export/connectionsPool.go
 delete mode 100644 dumpling/v4/export/ratelimit.go

diff --git a/dumpling/cmd/dumpling/main.go b/dumpling/cmd/dumpling/main.go
index 52816fe0..fe93d773 100644
--- a/dumpling/cmd/dumpling/main.go
+++ b/dumpling/cmd/dumpling/main.go
@@ -27,7 +27,7 @@ import (
 	"github.com/docker/go-units"
 	"github.com/pingcap/dumpling/v4/cli"
 	"github.com/pingcap/dumpling/v4/export"
-	"github.com/pingcap/log"
+	"github.com/pingcap/dumpling/v4/log"
 	filter "github.com/pingcap/tidb-tools/pkg/table-filter"
 	"github.com/spf13/pflag"
 	"go.uber.org/zap"
diff --git a/dumpling/v4/export/connectionsPool.go b/dumpling/v4/export/connectionsPool.go
new file mode 100644
index 00000000..0e8da805
--- /dev/null
+++ b/dumpling/v4/export/connectionsPool.go
@@ -0,0 +1,36 @@
+package export
+
+import (
+	"context"
+	"database/sql"
+)
+
+type connectionsPool struct {
+	conns chan *sql.Conn
+}
+
+func newConnectionsPool(ctx context.Context, n int, pool *sql.DB) (*connectionsPool, error) {
+	connectPool := &connectionsPool{
+		conns: make(chan *sql.Conn, n),
+	}
+	for i := 0; i < n; i++ {
+		conn, err := createConnWithConsistency(ctx, pool)
+		if err != nil {
+			return nil, err
+		}
+		connectPool.releaseConn(conn)
+	}
+	return connectPool, nil
+}
+
+func (r *connectionsPool) getConn() *sql.Conn {
+	return <-r.conns
+}
+
+func (r *connectionsPool) releaseConn(conn *sql.Conn) {
+	select {
+	case r.conns <- conn:
+	default:
+		panic("put a redundant conn")
+	}
+}
diff --git a/dumpling/v4/export/dump.go b/dumpling/v4/export/dump.go
index 4d9901e8..e42e4a0f 100644
--- a/dumpling/v4/export/dump.go
+++ b/dumpling/v4/export/dump.go
@@ -66,10 +66,15 @@ func Dump(pCtx context.Context, conf *Config) (err error) {
 	}
 
 	if conf.Snapshot == "" && (doPdGC || conf.Consistency == "snapshot") {
-		conf.Snapshot, err = getSnapshot(pool)
+		conn, err := pool.Conn(ctx)
+		if err != nil {
+			return withStack(err)
+		}
+		conf.Snapshot, err = getSnapshot(conn)
 		if err != nil {
 			return err
 		}
+		conn.Close()
 	}
 
 	if conf.Snapshot != "" {
@@ -100,9 +105,10 @@ func Dump(pCtx context.Context, conf *Config) (err error) {
 			"After dumping: run sql `update mysql.tidb set VARIABLE_VALUE = '10m' where VARIABLE_NAME = 'tikv_gc_life_time';` in tidb.\n")
 	}
 
-	pool, err = resetDBWithSessionParams(pool, conf.getDSN(""), conf.SessionParams)
-	if err != nil {
-		return err
+	if newPool, err := resetDBWithSessionParams(pool, conf.getDSN(""), conf.SessionParams); err != nil {
+		return withStack(err)
+	} else {
+		pool = newPool
 	}
 
 	m := newGlobalMetadata(conf.OutputDirPath)
@@ -112,14 +118,19 @@ func Dump(pCtx context.Context, conf *Config) (err error) {
 	// for consistency lock, we should lock tables at first to get the tables we want to lock & dump
 	// for consistency lock, record meta pos before lock tables because other tables may still be modified while locking tables
 	if conf.Consistency == "lock" {
+		conn, err := createConnWithConsistency(ctx, pool)
+		if err != nil {
+			return err
+		}
 		m.recordStartTime(time.Now())
-		err = m.recordGlobalMetaData(pool, conf.ServerInfo.ServerType)
+		err = m.recordGlobalMetaData(conn, conf.ServerInfo.ServerType)
 		if err != nil {
 			log.Info("get global metadata failed", zap.Error(err))
 		}
-		if err = prepareTableListToDump(conf, pool); err != nil {
+		if err = prepareTableListToDump(conf, conn); err != nil {
 			return err
 		}
+		conn.Close()
 	}
 
 	conCtrl, err := NewConsistencyController(conf, pool)
@@ -130,17 +141,28 @@ func Dump(pCtx context.Context, conf *Config) (err error) {
 		return err
 	}
 
+	connectPool, err := newConnectionsPool(ctx, conf.Threads, pool)
+	if err != nil {
+		return err
+	}
+
+	if err = conCtrl.TearDown(); err != nil {
+		return err
+	}
+
 	// for other consistencies, we should get table list after consistency is set up and GlobalMetaData is cached
 	// for other consistencies, record snapshot after whole tables are locked. The recorded meta info is exactly the locked snapshot.
 	if conf.Consistency != "lock" {
 		m.recordStartTime(time.Now())
-		err = m.recordGlobalMetaData(pool, conf.ServerInfo.ServerType)
+		conn := connectPool.getConn()
+		err = m.recordGlobalMetaData(conn, conf.ServerInfo.ServerType)
 		if err != nil {
 			log.Info("get global metadata failed", zap.Error(err))
 		}
-		if err = prepareTableListToDump(conf, pool); err != nil {
+		if err = prepareTableListToDump(conf, conn); err != nil {
 			return err
 		}
+		connectPool.releaseConn(conn)
 	}
 
 	var writer Writer
@@ -155,24 +177,26 @@ func Dump(pCtx context.Context, conf *Config) (err error) {
 	}
 
 	if conf.Sql == "" {
-		if err = dumpDatabases(ctx, conf, pool, writer); err != nil {
+		if err = dumpDatabases(ctx, conf, connectPool, writer); err != nil {
 			return err
 		}
 	} else {
-		if err = dumpSql(ctx, conf, pool, writer); err != nil {
+		if err = dumpSql(ctx, conf, connectPool, writer); err != nil {
 			return err
 		}
 	}
 
 	m.recordFinishTime(time.Now())
-
-	return conCtrl.TearDown()
+	return nil
 }
 
-func dumpDatabases(ctx context.Context, conf *Config, db *sql.DB, writer Writer) error {
+func dumpDatabases(ctx context.Context, conf *Config, connectPool *connectionsPool, writer Writer) error {
 	allTables := conf.Tables
+	var g errgroup.Group
 	for dbName, tables := range allTables {
-		createDatabaseSQL, err := ShowCreateDatabase(db, dbName)
+		conn := connectPool.getConn()
+		createDatabaseSQL, err := ShowCreateDatabase(conn, dbName)
+		connectPool.releaseConn(conn)
 		if err != nil {
 			return err
 		}
@@ -183,24 +207,35 @@ func dumpDatabases(ctx context.Context, conf *Config, db *sql.DB, writer Writer)
 		if len(tables) == 0 {
 			continue
 		}
-		rateLimit := newRateLimit(conf.Threads)
-		var g errgroup.Group
 		for _, table := range tables {
 			table := table
-			g.Go(func() error {
-				rateLimit.getToken()
-				defer rateLimit.putToken()
-				return dumpTable(ctx, conf, db, dbName, table, writer)
-			})
-		}
-		if err := g.Wait(); err != nil {
-			return err
+			conn := connectPool.getConn()
+			tableDataIRArray, err := dumpTable(ctx, conf, conn, dbName, table, writer)
+			if err != nil {
+				return err
+			}
+			connectPool.releaseConn(conn)
+			for _, tableIR := range tableDataIRArray {
+				tableIR := tableIR
+				g.Go(func() error {
+					conn := connectPool.getConn()
+					defer connectPool.releaseConn(conn)
+					err := tableIR.Start(ctx, conn)
+					if err != nil {
+						return err
+					}
+					return writer.WriteTableData(ctx, tableIR)
+				})
+			}
 		}
 	}
+	if err := g.Wait(); err != nil {
+		return err
+	}
 	return nil
 }
 
-func prepareTableListToDump(conf *Config, pool *sql.DB) error {
+func prepareTableListToDump(conf *Config, pool *sql.Conn) error {
 	databases, err := prepareDumpingDatabases(conf, pool)
 	if err != nil {
 		return err
@@ -223,8 +258,10 @@ func prepareTableListToDump(conf *Config, pool *sql.DB) error {
 	return nil
 }
 
-func dumpSql(ctx context.Context, conf *Config, db *sql.DB, writer Writer) error {
-	tableIR, err := SelectFromSql(conf, db)
+func dumpSql(ctx context.Context, conf *Config, connectPool *connectionsPool, writer Writer) error {
+	conn := connectPool.getConn()
+	tableIR, err := SelectFromSql(conf, conn)
+	connectPool.releaseConn(conn)
 	if err != nil {
 		return err
 	}
@@ -232,45 +269,45 @@ func dumpSql(ctx context.Context, conf *Config, db *sql.DB, writer Writer) error
 	return writer.WriteTableData(ctx, tableIR)
 }
 
-func dumpTable(ctx context.Context, conf *Config, db *sql.DB, dbName string, table *TableInfo, writer Writer) error {
+func dumpTable(ctx context.Context, conf *Config, db *sql.Conn, dbName string, table *TableInfo, writer Writer) ([]TableDataIR, error) {
 	tableName := table.Name
 	if !conf.NoSchemas {
 		if table.Type == TableTypeView {
 			viewName := table.Name
 			createViewSQL, err := ShowCreateView(db, dbName, viewName)
 			if err != nil {
-				return err
+				return nil, err
 			}
-			return writer.WriteTableMeta(ctx, dbName, viewName, createViewSQL)
+			return nil, writer.WriteTableMeta(ctx, dbName, viewName, createViewSQL)
 		}
 		createTableSQL, err := ShowCreateTable(db, dbName, tableName)
 		if err != nil {
-			return err
+			return nil, err
 		}
 		if err := writer.WriteTableMeta(ctx, dbName, tableName, createTableSQL); err != nil {
-			return err
+			return nil, err
 		}
 	}
 	// Do not dump table data and return nil
 	if conf.NoData {
-		return nil
+		return nil, nil
 	}
 
 	if conf.Rows != UnspecifiedSize {
-		finished, err := concurrentDumpTable(ctx, writer, conf, db, dbName, tableName)
+		finished, chunksIterArray, err := concurrentDumpTable(ctx, conf, db, dbName, tableName)
 		if err != nil || finished {
-			return err
+			return chunksIterArray, err
 		}
 	}
 	tableIR, err := SelectAllFromTable(conf, db, dbName, tableName)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
-	return writer.WriteTableData(ctx, tableIR)
+	return []TableDataIR{tableIR}, nil
 }
 
-func concurrentDumpTable(ctx context.Context, writer Writer, conf *Config, db *sql.DB, dbName string, tableName string) (bool, error) {
+func concurrentDumpTable(ctx context.Context, conf *Config, db *sql.Conn, dbName string, tableName string) (bool, []TableDataIR, error) {
 	// try dump table concurrently by split table to chunks
 	chunksIterCh := make(chan TableDataIR, defaultDumpThreads)
 	errCh := make(chan error, defaultDumpThreads)
@@ -279,6 +316,7 @@ func concurrentDumpTable(ctx context.Context, writer Writer, conf *Config, db *s
 	ctx1, cancel1 := context.WithCancel(ctx)
 	defer cancel1()
 	var g errgroup.Group
+	chunksIterArray := make([]TableDataIR, 0)
 	g.Go(func() error {
 		splitTableDataIntoChunks(ctx1, chunksIterCh, errCh, linear, dbName, tableName, db, conf)
 		return nil
@@ -288,24 +326,22 @@ Loop:
 	for {
 		select {
 		case <-ctx.Done():
-			return true, nil
+			return true, chunksIterArray, nil
 		case <-linear:
-			return false, nil
+			return false, chunksIterArray, nil
 		case chunksIter, ok := <-chunksIterCh:
 			if !ok {
 				break Loop
 			}
-			g.Go(func() error {
-				return writer.WriteTableData(ctx, chunksIter)
-			})
+			chunksIterArray = append(chunksIterArray, chunksIter)
 		case err := <-errCh:
-			return false, err
+			return false, chunksIterArray, err
 		}
 	}
 	if err := g.Wait(); err != nil {
-		return true, err
+		return true, chunksIterArray, err
 	}
-	return true, nil
+	return true, chunksIterArray, nil
 }
 
 func updateServiceSafePoint(ctx context.Context, pdClient pd.Client, ttl int64, snapshotTS uint64) {
diff --git a/dumpling/v4/export/dump_test.go b/dumpling/v4/export/dump_test.go
index afa7163a..18055eac 100644
--- a/dumpling/v4/export/dump_test.go
+++ b/dumpling/v4/export/dump_test.go
@@ -2,6 +2,7 @@ package export
 
 import (
 	"context"
+	"database/sql"
 	"fmt"
 	"strconv"
 
@@ -27,6 +28,14 @@ func newMockWriter() *mockWriter {
 	}
 }
 
+func newMockConnectPool(c *C, db *sql.DB) *connectionsPool {
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
+	connectPool := &connectionsPool{conns: make(chan *sql.Conn, 1)}
+	connectPool.releaseConn(conn)
+	return connectPool
+}
+
 func (m *mockWriter) WriteDatabaseMeta(ctx context.Context, db, createSQL string) error {
 	m.databaseMeta[db] = createSQL
 	return nil
@@ -64,7 +73,8 @@ func (s *testDumpSuite) TestDumpDatabase(c *C) {
 	mock.ExpectQuery("SELECT (.) FROM `test`.`t`").WillReturnRows(rows)
 
 	mockWriter := newMockWriter()
-	err = dumpDatabases(context.Background(), mockConfig, db, mockWriter)
+	connectPool := newMockConnectPool(c, db)
+	err = dumpDatabases(context.Background(), mockConfig, connectPool, mockWriter)
 	c.Assert(err, IsNil)
 
 	c.Assert(len(mockWriter.databaseMeta), Equals, 1)
@@ -78,6 +88,8 @@ func (s *testDumpSuite) TestDumpTable(c *C) {
 	mockConfig.SortByPk = false
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	showCreateTableResult := "CREATE TABLE t (a INT)"
 	rows := mock.NewRows([]string{"Table", "Create Table"}).AddRow("t", showCreateTableResult)
@@ -90,8 +102,13 @@ func (s *testDumpSuite) TestDumpTable(c *C) {
 	mock.ExpectQuery("SELECT (.) FROM `test`.`t`").WillReturnRows(rows)
 
 	mockWriter := newMockWriter()
-	err = dumpTable(context.Background(), mockConfig, db, "test", &TableInfo{Name: "t"}, mockWriter)
+	ctx := context.Background()
+	tableIRArray, err := dumpTable(ctx, mockConfig, conn, "test", &TableInfo{Name: "t"}, mockWriter)
 	c.Assert(err, IsNil)
+	for _, tableIR := range tableIRArray {
+		c.Assert(tableIR.Start(ctx, conn), IsNil)
+		c.Assert(mockWriter.WriteTableData(ctx, tableIR), IsNil)
+	}
 
 	c.Assert(mockWriter.tableMeta["test.t"], Equals, showCreateTableResult)
 	c.Assert(len(mockWriter.tableData), Equals, 1)
@@ -121,6 +138,8 @@ func (s *testDumpSuite) TestDumpTableWhereClause(c *C) {
 	mockConfig.Where = "a > 3 and a < 9"
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	showCreateTableResult := "CREATE TABLE t (a INT)"
 	rows := mock.NewRows([]string{"Table", "Create Table"}).AddRow("t", showCreateTableResult)
@@ -137,8 +156,13 @@ func (s *testDumpSuite) TestDumpTableWhereClause(c *C) {
 	mock.ExpectQuery("SELECT (.) FROM `test`.`t` WHERE a > 3 and a < 9").WillReturnRows(rows)
 
 	mockWriter := newMockWriter()
-	err = dumpTable(context.Background(), mockConfig, db, "test", &TableInfo{Name: "t"}, mockWriter)
+	ctx := context.Background()
+	tableIRArray, err := dumpTable(ctx, mockConfig, conn, "test", &TableInfo{Name: "t"}, mockWriter)
 	c.Assert(err, IsNil)
+	for _, tableIR := range tableIRArray {
+		c.Assert(tableIR.Start(ctx, conn), IsNil)
+		c.Assert(mockWriter.WriteTableData(ctx, tableIR), IsNil)
+	}
 
 	c.Assert(mockWriter.tableMeta["test.t"], Equals, showCreateTableResult)
 	c.Assert(len(mockWriter.tableData), Equals, 1)
diff --git a/dumpling/v4/export/ir.go b/dumpling/v4/export/ir.go
index 71a71a3f..d7dc33e7 100644
--- a/dumpling/v4/export/ir.go
+++ b/dumpling/v4/export/ir.go
@@ -2,11 +2,13 @@ package export
 
 import (
 	"bytes"
+	"context"
 	"database/sql"
 )
 
 // TableDataIR is table data intermediate representation.
 type TableDataIR interface {
+	Start(context.Context, *sql.Conn) error
 	DatabaseName() string
 	TableName() string
 	ChunkIndex() int
diff --git a/dumpling/v4/export/ir_impl.go b/dumpling/v4/export/ir_impl.go
index 321ca549..5da9a54a 100644
--- a/dumpling/v4/export/ir_impl.go
+++ b/dumpling/v4/export/ir_impl.go
@@ -146,6 +146,7 @@ func (m *stringIter) HasNext() bool {
 type tableData struct {
 	database        string
 	table           string
+	query           string
 	chunkIndex      int
 	rows            *sql.Rows
 	colTypes        []*sql.ColumnType
@@ -154,6 +155,15 @@ type tableData struct {
 	escapeBackslash bool
 }
 
+func (td *tableData) Start(ctx context.Context, conn *sql.Conn) error {
+	rows, err := conn.QueryContext(ctx, td.query)
+	if err != nil {
+		return err
+	}
+	td.rows = rows
+	return nil
+}
+
 func (td *tableData) ColumnTypes() []string {
 	colTypes := make([]string, len(td.colTypes))
 	for i, ct := range td.colTypes {
@@ -241,7 +251,7 @@ func splitTableDataIntoChunks(
 	tableDataIRCh chan TableDataIR,
 	errCh chan error,
 	linear chan struct{},
-	dbName, tableName string, db *sql.DB, conf *Config) {
+	dbName, tableName string, db *sql.Conn, conf *Config) {
 	field, err := pickupPossibleField(dbName, tableName, db, conf)
 	if err != nil {
 		errCh <- withStack(err)
@@ -263,7 +273,7 @@ func splitTableDataIntoChunks(
 
 	var smin sql.NullString
 	var smax sql.NullString
-	row := db.QueryRow(query)
+	row := db.QueryRowContext(ctx, query)
 	err = row.Scan(&smin, &smax)
 	if err != nil {
 		log.Error("split chunks - get max min failed", zap.String("query", query), zap.Error(err))
@@ -329,11 +339,6 @@ LOOP:
 		chunkIndex += 1
 		where := fmt.Sprintf("%s(`%s` >= %d AND `%s` < %d)", nullValueCondition, escapeString(field), cutoff, escapeString(field), cutoff+estimatedStep)
 		query = buildSelectQuery(dbName, tableName, selectedField, buildWhereCondition(conf, where), orderByClause)
-		rows, err := db.Query(query)
-		if err != nil {
-			errCh <- errors.WithMessage(err, query)
-			return
-		}
 		if len(nullValueCondition) > 0 {
 			nullValueCondition = ""
 		}
@@ -341,7 +346,7 @@ LOOP:
 		td := &tableData{
 			database:      dbName,
 			table:         tableName,
-			rows:          rows,
+			query:         query,
 			chunkIndex:    chunkIndex,
 			colTypes:      colTypes,
 			selectedField: selectedField,
diff --git a/dumpling/v4/export/metadata.go b/dumpling/v4/export/metadata.go
index 36e14af4..63d5d8c8 100644
--- a/dumpling/v4/export/metadata.go
+++ b/dumpling/v4/export/metadata.go
@@ -2,6 +2,7 @@ package export
 
 import (
 	"bytes"
+	"context"
 	"database/sql"
 	"errors"
 	"fmt"
@@ -46,7 +47,7 @@ func (m *globalMetadata) recordFinishTime(t time.Time) {
 	m.buffer.WriteString("Finished dump at: " + t.Format(metadataTimeLayout) + "\n")
 }
 
-func (m *globalMetadata) recordGlobalMetaData(db *sql.DB, serverType ServerType) error {
+func (m *globalMetadata) recordGlobalMetaData(db *sql.Conn, serverType ServerType) error {
 	// get master status info
 	m.buffer.WriteString("SHOW MASTER STATUS:\n")
 	switch serverType {
@@ -107,7 +108,7 @@ func (m *globalMetadata) recordGlobalMetaData(db *sql.DB, serverType ServerType)
 			m.buffer.WriteString("\tPos: " + pos + "\n")
 		}
 		var gtidSet string
-		err = db.QueryRow("SELECT @@global.gtid_binlog_pos").Scan(&gtidSet)
+		err = db.QueryRowContext(context.Background(), "SELECT @@global.gtid_binlog_pos").Scan(&gtidSet)
 		if err != nil {
 			return err
 		}
diff --git a/dumpling/v4/export/metadata_test.go b/dumpling/v4/export/metadata_test.go
index e412ed96..0918f38d 100644
--- a/dumpling/v4/export/metadata_test.go
+++ b/dumpling/v4/export/metadata_test.go
@@ -1,6 +1,7 @@
 package export
 
 import (
+	"context"
 	"fmt"
 	"path"
 
@@ -16,6 +17,8 @@ func (s *testMetaDataSuite) TestMysqlMetaData(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	logFile := "ON.000001"
 	pos := "7502"
@@ -29,7 +32,7 @@ func (s *testMetaDataSuite) TestMysqlMetaData(c *C) {
 
 	testFilePath := "/test"
 	m := newGlobalMetadata(testFilePath)
-	c.Assert(m.recordGlobalMetaData(db, ServerTypeMySQL), IsNil)
+	c.Assert(m.recordGlobalMetaData(conn, ServerTypeMySQL), IsNil)
 	c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath))
 
 	c.Assert(m.buffer.String(), Equals, "SHOW MASTER STATUS:\n"+
@@ -43,6 +46,8 @@ func (s *testMetaDataSuite) TestMysqlWithFollowersMetaData(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	logFile := "ON.000001"
 	pos := "7502"
@@ -57,7 +62,7 @@ func (s *testMetaDataSuite) TestMysqlWithFollowersMetaData(c *C) {
 
 	testFilePath := "/test"
 	m := newGlobalMetadata(testFilePath)
-	c.Assert(m.recordGlobalMetaData(db, ServerTypeMySQL), IsNil)
+	c.Assert(m.recordGlobalMetaData(conn, ServerTypeMySQL), IsNil)
 	c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath))
 
 	c.Assert(m.buffer.String(), Equals, "SHOW MASTER STATUS:\n"+
@@ -76,6 +81,8 @@ func (s *testMetaDataSuite) TestMariaDBMetaData(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	logFile := "mariadb-bin.000016"
 	pos := "475"
@@ -89,7 +96,7 @@ func (s *testMetaDataSuite) TestMariaDBMetaData(c *C) {
 	mock.ExpectQuery("SHOW SLAVE STATUS").WillReturnRows(rows)
 	testFilePath := "/test"
 	m := newGlobalMetadata(testFilePath)
-	c.Assert(m.recordGlobalMetaData(db, ServerTypeMariaDB), IsNil)
+	c.Assert(m.recordGlobalMetaData(conn, ServerTypeMariaDB), IsNil)
 	c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath))
 
 	c.Assert(mock.ExpectationsWereMet(), IsNil)
@@ -99,6 +106,8 @@ func (s *testMetaDataSuite) TestMariaDBWithFollowersMetaData(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	logFile := "ON.000001"
 	pos := "7502"
@@ -116,7 +125,7 @@ func (s *testMetaDataSuite) TestMariaDBWithFollowersMetaData(c *C) {
 
 	testFilePath := "/test"
 	m := newGlobalMetadata(testFilePath)
-	c.Assert(m.recordGlobalMetaData(db, ServerTypeMySQL), IsNil)
+	c.Assert(m.recordGlobalMetaData(conn, ServerTypeMySQL), IsNil)
 	c.Assert(m.filePath, Equals, path.Join(testFilePath, metadataPath))
 
 	c.Assert(m.buffer.String(), Equals, "SHOW MASTER STATUS:\n"+
diff --git a/dumpling/v4/export/prepare.go b/dumpling/v4/export/prepare.go
index 3d123c73..91ea562b 100644
--- a/dumpling/v4/export/prepare.go
+++ b/dumpling/v4/export/prepare.go
@@ -64,7 +64,7 @@ func detectServerInfo(db *sql.DB) (ServerInfo, error) {
 	return ParseServerInfo(versionStr), nil
 }
 
-func prepareDumpingDatabases(conf *Config, db *sql.DB) ([]string, error) {
+func prepareDumpingDatabases(conf *Config, db *sql.Conn) ([]string, error) {
 	databases, err := ShowDatabases(db)
 	if len(conf.Databases) == 0 {
 		return databases, err
@@ -86,12 +86,12 @@ func prepareDumpingDatabases(conf *Config, db *sql.DB) ([]string, error) {
 	}
 }
 
-func listAllTables(db *sql.DB, databaseNames []string) (DatabaseTables, error) {
+func listAllTables(db *sql.Conn, databaseNames []string) (DatabaseTables, error) {
 	log.Debug("list all the tables")
 	return ListAllDatabasesTables(db, databaseNames, TableTypeBase)
 }
 
-func listAllViews(db *sql.DB, databaseNames []string) (DatabaseTables, error) {
+func listAllViews(db *sql.Conn, databaseNames []string) (DatabaseTables, error) {
 	log.Debug("list all the views")
 	return ListAllDatabasesTables(db, databaseNames, TableTypeView)
 }
diff --git a/dumpling/v4/export/prepare_test.go b/dumpling/v4/export/prepare_test.go
index 7233e925..3fc038b5 100644
--- a/dumpling/v4/export/prepare_test.go
+++ b/dumpling/v4/export/prepare_test.go
@@ -1,6 +1,7 @@
 package export
 
 import (
+	"context"
 	"fmt"
 
 	"github.com/DATA-DOG/go-sqlmock"
@@ -15,6 +16,8 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	rows := sqlmock.NewRows([]string{"Database"}).
 		AddRow("db1").
@@ -24,7 +27,7 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) {
 	mock.ExpectQuery("SHOW DATABASES").WillReturnRows(rows)
 	conf := DefaultConfig()
 	conf.Databases = []string{"db1", "db2", "db3"}
-	result, err := prepareDumpingDatabases(conf, db)
+	result, err := prepareDumpingDatabases(conf, conn)
 	c.Assert(err, IsNil)
 	c.Assert(result, DeepEquals, []string{"db1", "db2", "db3"})
 
@@ -33,12 +36,12 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) {
 		AddRow("db1").
 		AddRow("db2")
 	mock.ExpectQuery("SHOW DATABASES").WillReturnRows(rows)
-	result, err = prepareDumpingDatabases(conf, db)
+	result, err = prepareDumpingDatabases(conf, conn)
 	c.Assert(err, IsNil)
 	c.Assert(result, DeepEquals, []string{"db1", "db2"})
 
 	mock.ExpectQuery("SHOW DATABASES").WillReturnError(fmt.Errorf("err"))
-	_, err = prepareDumpingDatabases(conf, db)
+	_, err = prepareDumpingDatabases(conf, conn)
 	c.Assert(err, NotNil)
 
 	rows = sqlmock.NewRows([]string{"Database"}).
@@ -48,7 +51,7 @@ func (s *testPrepareSuite) TestPrepareDumpingDatabases(c *C) {
 		AddRow("db5")
 	mock.ExpectQuery("SHOW DATABASES").WillReturnRows(rows)
 	conf.Databases = []string{"db1", "db2", "db4", "db6"}
-	_, err = prepareDumpingDatabases(conf, db)
+	_, err = prepareDumpingDatabases(conf, conn)
 	c.Assert(err, ErrorMatches, `Unknown databases \[db4,db6\]`)
 	c.Assert(mock.ExpectationsWereMet(), IsNil)
 }
@@ -57,6 +60,8 @@ func (s *testPrepareSuite) TestListAllTables(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	data := NewDatabaseTables().
 		AppendTables("db1", "t1", "t2").
@@ -78,7 +83,7 @@ func (s *testPrepareSuite) TestListAllTables(c *C) {
 	query := "SELECT table_schema,table_name FROM information_schema.tables WHERE table_type = (.*)"
 	mock.ExpectQuery(query).WillReturnRows(rows)
 
-	tables, err := listAllTables(db, dbNames)
+	tables, err := listAllTables(conn, dbNames)
 	c.Assert(err, IsNil)
 
 	for d, t := range tables {
@@ -96,7 +101,7 @@ func (s *testPrepareSuite) TestListAllTables(c *C) {
 		AppendViews("db", "t2")
 	query = "SELECT table_schema,table_name FROM information_schema.tables WHERE table_type = (.*)"
 	mock.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"table_schema", "table_name"}).AddRow("db", "t2"))
-	tables, err = listAllViews(db, []string{"db"})
+	tables, err = listAllViews(conn, []string{"db"})
 	c.Assert(err, IsNil)
 	c.Assert(len(tables), Equals, 1)
 	c.Assert(len(tables["db"]), Equals, 1)
diff --git a/dumpling/v4/export/ratelimit.go b/dumpling/v4/export/ratelimit.go
deleted file mode 100644
index aafb5695..00000000
--- a/dumpling/v4/export/ratelimit.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package export
-
-type rateLimit struct {
-	token chan struct{}
-}
-
-func newRateLimit(n int) *rateLimit {
-	return &rateLimit{
-		token: make(chan struct{}, n),
-	}
-}
-
-func (r *rateLimit) getToken() {
-	r.token <- struct{}{}
-}
-
-func (r *rateLimit) putToken() {
-	select {
-	case <-r.token:
-	default:
-		panic("put a redundant token")
-	}
-}
diff --git a/dumpling/v4/export/sql.go b/dumpling/v4/export/sql.go
index 4a9396ea..a5bfbfd0 100644
--- a/dumpling/v4/export/sql.go
+++ b/dumpling/v4/export/sql.go
@@ -1,6 +1,7 @@
 package export
 
 import (
+	"context"
 	"database/sql"
 	"fmt"
 	"net/url"
@@ -13,7 +14,7 @@ import (
 	"github.com/pingcap/dumpling/v4/log"
 )
 
-func ShowDatabases(db *sql.DB) ([]string, error) {
+func ShowDatabases(db *sql.Conn) ([]string, error) {
 	var res oneStrColumnTable
 	if err := simpleQuery(db, "SHOW DATABASES", res.handleOneRow); err != nil {
 		return nil, err
@@ -22,7 +23,7 @@ func ShowDatabases(db *sql.DB) ([]string, error) {
 }
 
 // ShowTables shows the tables of a database, the caller should use the correct database.
-func ShowTables(db *sql.DB) ([]string, error) {
+func ShowTables(db *sql.Conn) ([]string, error) {
 	var res oneStrColumnTable
 	if err := simpleQuery(db, "SHOW TABLES", res.handleOneRow); err != nil {
 		return nil, err
@@ -30,7 +31,7 @@ func ShowTables(db *sql.DB) ([]string, error) {
 	return res.data, nil
 }
 
-func ShowCreateDatabase(db *sql.DB, database string) (string, error) {
+func ShowCreateDatabase(db *sql.Conn, database string) (string, error) {
 	var oneRow [2]string
 	handleOneRow := func(rows *sql.Rows) error {
 		return rows.Scan(&oneRow[0], &oneRow[1])
@@ -43,7 +44,7 @@ func ShowCreateDatabase(db *sql.DB, database string) (string, error) {
 	return oneRow[1], nil
 }
 
-func ShowCreateTable(db *sql.DB, database, table string) (string, error) {
+func ShowCreateTable(db *sql.Conn, database, table string) (string, error) {
 	var oneRow [2]string
 	handleOneRow := func(rows *sql.Rows) error {
 		return rows.Scan(&oneRow[0], &oneRow[1])
@@ -56,7 +57,7 @@ func ShowCreateTable(db *sql.DB, database, table string) (string, error) {
 	return oneRow[1], nil
 }
 
-func ShowCreateView(db *sql.DB, database, view string) (string, error) {
+func ShowCreateView(db *sql.Conn, database, view string) (string, error) {
 	var oneRow [4]string
 	handleOneRow := func(rows *sql.Rows) error {
 		return rows.Scan(&oneRow[0], &oneRow[1], &oneRow[2], &oneRow[3])
@@ -69,16 +70,7 @@ func ShowCreateView(db *sql.DB, database, view string) (string, error) {
 	return oneRow[1], nil
 }
 
-func ListAllTables(db *sql.DB, database string) ([]string, error) {
-	var tables oneStrColumnTable
-	const query = "SELECT table_name FROM information_schema.tables WHERE table_schema = ? and table_type = 'BASE TABLE'"
-	if err := simpleQueryWithArgs(db, tables.handleOneRow, query, database); err != nil {
-		return nil, errors.WithMessage(err, query)
-	}
-	return tables.data, nil
-}
-
-func ListAllDatabasesTables(db *sql.DB, databaseNames []string, tableType TableType) (DatabaseTables, error) {
+func ListAllDatabasesTables(db *sql.Conn, databaseNames []string, tableType TableType) (DatabaseTables, error) {
 	var tableTypeStr string
 	switch tableType {
 	case TableTypeBase:
@@ -112,7 +104,16 @@ func ListAllDatabasesTables(db *sql.DB, databaseNames []string, tableType TableT
 	return dbTables, nil
 }
 
-func ListAllViews(db *sql.DB, database string) ([]string, error) {
+func ListAllTables(db *sql.Conn, database string) ([]string, error) {
+	var tables oneStrColumnTable
+	const query = "SELECT table_name FROM information_schema.tables WHERE table_schema = ? and table_type = 'BASE TABLE'"
+	if err := simpleQueryWithArgs(db, tables.handleOneRow, query, database); err != nil {
+		return nil, errors.WithMessage(err, query)
+	}
+	return tables.data, nil
+}
+
+func ListAllViews(db *sql.Conn, database string) ([]string, error) {
 	var views oneStrColumnTable
 	const query = "SELECT table_name FROM information_schema.tables WHERE table_schema = ? and table_type = 'VIEW'"
 	if err := simpleQueryWithArgs(db, views.handleOneRow, query, database); err != nil {
@@ -123,17 +124,15 @@ func ListAllViews(db *sql.DB, database string) ([]string, error) {
 
 func SelectVersion(db *sql.DB) (string, error) {
 	var versionInfo string
-	handleOneRow := func(rows *sql.Rows) error {
-		return rows.Scan(&versionInfo)
-	}
-	err := simpleQuery(db, "SELECT version()", handleOneRow)
+	row := db.QueryRow("SELECT version()")
+	err := row.Scan(&versionInfo)
 	if err != nil {
 		return "", withStack(err)
 	}
 	return versionInfo, nil
 }
 
-func SelectAllFromTable(conf *Config, db *sql.DB, database, table string) (TableDataIR, error) {
+func SelectAllFromTable(conf *Config, db *sql.Conn, database, table string) (TableDataIR, error) {
 	selectedField, err := buildSelectField(db, database, table)
 	if err != nil {
 		return nil, err
@@ -150,15 +149,11 @@ func SelectAllFromTable(conf *Config, db *sql.DB, database, table string) (Table
 	}
 
 	query := buildSelectQuery(database, table, selectedField, buildWhereCondition(conf, ""), orderByClause)
-	rows, err := db.Query(query)
-	if err != nil {
-		return nil, withStack(errors.WithMessage(err, query))
-	}
 
 	return &tableData{
 		database:        database,
 		table:           table,
-		rows:            rows,
+		query:           query,
 		colTypes:        colTypes,
 		selectedField:   selectedField,
 		escapeBackslash: conf.EscapeBackslash,
@@ -168,8 +163,8 @@ func SelectAllFromTable(conf *Config, db *sql.DB, database, table string) (Table
 	}, nil
 }
 
-func SelectFromSql(conf *Config, db *sql.DB) (TableDataIR, error) {
-	rows, err := db.Query(conf.Sql)
+func SelectFromSql(conf *Config, db *sql.Conn) (TableDataIR, error) {
+	rows, err := db.QueryContext(context.Background(), conf.Sql)
 	if err != nil {
 		return nil, withStack(errors.WithMessage(err, conf.Sql))
 	}
@@ -220,7 +215,7 @@ func buildSelectQuery(database, table string, fields string, where string, order
 	return query.String()
 }
 
-func buildOrderByClause(conf *Config, db *sql.DB, database, table string) (string, error) {
+func buildOrderByClause(conf *Config, db *sql.Conn, database, table string) (string, error) {
 	if !conf.SortByPk {
 		return "", nil
 	}
@@ -246,10 +241,10 @@ func buildOrderByClause(conf *Config, db *sql.DB, database, table string) (strin
 	return "", nil
 }
 
-func SelectTiDBRowID(db *sql.DB, database, table string) (bool, error) {
+func SelectTiDBRowID(db *sql.Conn, database, table string) (bool, error) {
 	const errBadFieldCode = 1054
 	tiDBRowIDQuery := fmt.Sprintf("SELECT _tidb_rowid from `%s`.`%s` LIMIT 0", escapeString(database), escapeString(table))
-	_, err := db.Exec(tiDBRowIDQuery)
+	_, err := db.ExecContext(context.Background(), tiDBRowIDQuery)
 	if err != nil {
 		errMsg := strings.ToLower(err.Error())
 		if strings.Contains(errMsg, fmt.Sprintf("%d", errBadFieldCode)) {
@@ -260,9 +255,9 @@ func SelectTiDBRowID(db *sql.DB, database, table string) (bool, error) {
 	return true, nil
 }
 
-func GetColumnTypes(db *sql.DB, fields, database, table string) ([]*sql.ColumnType, error) {
+func GetColumnTypes(db *sql.Conn, fields, database, table string) ([]*sql.ColumnType, error) {
 	query := fmt.Sprintf("SELECT %s FROM `%s`.`%s` LIMIT 1", fields, escapeString(database), escapeString(table))
-	rows, err := db.Query(query)
+	rows, err := db.QueryContext(context.Background(), query)
 	if err != nil {
 		return nil, withStack(errors.WithMessage(err, query))
 	}
@@ -270,11 +265,11 @@ func GetColumnTypes(db *sql.DB, fields, database, table string) ([]*sql.ColumnTy
 	return rows.ColumnTypes()
 }
 
-func GetPrimaryKeyName(db *sql.DB, database, table string) (string, error) {
+func GetPrimaryKeyName(db *sql.Conn, database, table string) (string, error) {
 	priKeyQuery := "SELECT column_name FROM information_schema.columns " +
 		"WHERE table_schema = ? AND table_name = ? AND column_key = 'PRI';"
 	var colName string
-	row := db.QueryRow(priKeyQuery, database, table)
+	row := db.QueryRowContext(context.Background(), priKeyQuery, database, table)
 	if err := row.Scan(&colName); err != nil {
 		if err == sql.ErrNoRows {
 			return "", nil
@@ -285,11 +280,11 @@ func GetPrimaryKeyName(db *sql.DB, database, table string) (string, error) {
 	return colName, nil
 }
 
-func GetUniqueIndexName(db *sql.DB, database, table string) (string, error) {
+func GetUniqueIndexName(db *sql.Conn, database, table string) (string, error) {
 	uniKeyQuery := "SELECT column_name FROM information_schema.columns " +
 		"WHERE table_schema = ? AND table_name = ? AND column_key = 'UNI';"
 	var colName string
-	row := db.QueryRow(uniKeyQuery, database, table)
+	row := db.QueryRowContext(context.Background(), uniKeyQuery, database, table)
 	if err := row.Scan(&colName); err != nil {
 		if err == sql.ErrNoRows {
 			return "", nil
@@ -320,7 +315,7 @@ func UseDatabase(db *sql.DB, databaseName string) error {
 	return withStack(err)
 }
 
-func ShowMasterStatus(db *sql.DB, fieldNum int) ([]string, error) {
+func ShowMasterStatus(db *sql.Conn, fieldNum int) ([]string, error) {
 	oneRow := make([]string, fieldNum)
 	addr := make([]interface{}, fieldNum)
 	for i := range oneRow {
@@ -387,17 +382,15 @@ func GetTiDBDDLIDs(db *sql.DB) ([]string, error) {
 
 func CheckTiDBWithTiKV(db *sql.DB) (bool, error) {
 	var count int
-	handleOneRow := func(rows *sql.Rows) error {
-		return rows.Scan(&count)
-	}
-	err := simpleQuery(db, "SELECT COUNT(1) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'", handleOneRow)
+	row := db.QueryRow("SELECT COUNT(1) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'")
+	err := row.Scan(&count)
 	if err != nil {
 		return false, err
 	}
 	return count > 0, nil
 }
 
-func getSnapshot(db *sql.DB) (string, error) {
+func getSnapshot(db *sql.Conn) (string, error) {
 	str, err := ShowMasterStatus(db, showMasterStatusFieldNum)
 	if err != nil {
 		return "", err
@@ -445,9 +438,25 @@ func resetDBWithSessionParams(db *sql.DB, dsn string, params map[string]interfac
 	return newDB, nil
 }
 
-func buildSelectField(db *sql.DB, dbName, tableName string) (string, error) {
+func createConnWithConsistency(ctx context.Context, db *sql.DB) (*sql.Conn, error) {
+	conn, err := db.Conn(ctx)
+	if err != nil {
+		return nil, err
+	}
+	_, err = conn.ExecContext(ctx, "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
+	if err != nil {
+		return nil, err
+	}
+	_, err = conn.ExecContext(ctx, "START TRANSACTION /*!40108 WITH CONSISTENT SNAPSHOT */")
+	if err != nil {
+		return nil, err
+	}
+	return conn, err
+}
+
+func buildSelectField(db *sql.Conn, dbName, tableName string) (string, error) {
 	query := `SELECT COLUMN_NAME,EXTRA FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA=? AND TABLE_NAME=?;`
-	rows, err := db.Query(query, dbName, tableName)
+	rows, err := db.QueryContext(context.Background(), query, dbName, tableName)
 	if err != nil {
 		return "", err
 	}
@@ -488,12 +497,12 @@ func (o *oneStrColumnTable) handleOneRow(rows *sql.Rows) error {
 	return nil
 }
 
-func simpleQuery(db *sql.DB, sql string, handleOneRow func(*sql.Rows) error) error {
-	return simpleQueryWithArgs(db, handleOneRow, sql)
+func simpleQuery(conn *sql.Conn, sql string, handleOneRow func(*sql.Rows) error) error {
+	return simpleQueryWithArgs(conn, handleOneRow, sql)
 }
 
-func simpleQueryWithArgs(db *sql.DB, handleOneRow func(*sql.Rows) error, sql string, args ...interface{}) error {
-	rows, err := db.Query(sql, args...)
+func simpleQueryWithArgs(conn *sql.Conn, handleOneRow func(*sql.Rows) error, sql string, args ...interface{}) error {
+	rows, err := conn.QueryContext(context.Background(), sql, args...)
 	if err != nil {
 		return withStack(err)
 	}
@@ -507,7 +516,7 @@ func simpleQueryWithArgs(db *sql.DB, handleOneRow func(*sql.Rows) error, sql str
 	return rows.Err()
 }
 
-func pickupPossibleField(dbName, tableName string, db *sql.DB, conf *Config) (string, error) {
+func pickupPossibleField(dbName, tableName string, db *sql.Conn, conf *Config) (string, error) {
 	// If detected server is TiDB, try using _tidb_rowid
 	if conf.ServerInfo.ServerType == ServerTypeTiDB {
 		ok, err := SelectTiDBRowID(db, dbName, tableName)
@@ -539,7 +548,7 @@ func pickupPossibleField(dbName, tableName string, db *sql.DB, conf *Config) (st
 	query := "SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS " +
 		"WHERE TABLE_NAME = ? AND COLUMN_NAME = ?"
 	var fieldType string
-	row := db.QueryRow(query, tableName, fieldName)
+	row := db.QueryRowContext(context.Background(), query, tableName, fieldName)
 	err = row.Scan(&fieldType)
 	if err != nil {
 		if err == sql.ErrNoRows {
@@ -555,7 +564,7 @@ func pickupPossibleField(dbName, tableName string, db *sql.DB, conf *Config) (st
 	return "", nil
 }
 
-func estimateCount(dbName, tableName string, db *sql.DB, field string, conf *Config) uint64 {
+func estimateCount(dbName, tableName string, db *sql.Conn, field string, conf *Config) uint64 {
 	query := fmt.Sprintf("EXPLAIN SELECT `%s` FROM `%s`.`%s`", field, escapeString(dbName), escapeString(tableName))
 
 	if conf.Where != "" {
@@ -592,13 +601,14 @@ func estimateCount(dbName, tableName string, db *sql.DB, field string, conf *Con
 	return 0
 }
 
-func detectEstimateRows(db *sql.DB, query string, fieldNames []string) uint64 {
-	row, err := db.Query(query)
+func detectEstimateRows(db *sql.Conn, query string, fieldNames []string) uint64 {
+	row, err := db.QueryContext(context.Background(), query)
 	if err != nil {
 		log.Warn("can't execute query from db",
 			zap.String("query", query), zap.Error(err))
 		return 0
 	}
+	defer row.Close()
 	row.Next()
 	columns, _ := row.Columns()
 	addr := make([]interface{}, len(columns))
@@ -638,19 +648,14 @@ func parseSnapshotToTSO(pool *sql.DB, snapshot string) (uint64, error) {
 		return snapshotTS, nil
 	}
 	var tso sql.NullInt64
-	err = simpleQueryWithArgs(pool, func(rows *sql.Rows) error {
-		err := rows.Scan(&tso)
-		if err != nil {
-			return err
-		}
-		if !tso.Valid {
-			return fmt.Errorf("snapshot %s format not supported. please use tso or '2006-01-02 15:04:05' format time", snapshot)
-		}
-		return nil
-	}, "SELECT unix_timestamp(?)", snapshot)
+	row := pool.QueryRow("SELECT unix_timestamp(?)", snapshot)
+	err = row.Scan(&tso)
 	if err != nil {
 		return 0, withStack(err)
 	}
+	if !tso.Valid {
+		return 0, withStack(fmt.Errorf("snapshot %s format not supported. please use tso or '2006-01-02 15:04:05' format time", snapshot))
+	}
 	return (uint64(tso.Int64)<<18)*1000 + 1, nil
 }
 
diff --git a/dumpling/v4/export/sql_test.go b/dumpling/v4/export/sql_test.go
index 79bc2b10..ee02e534 100644
--- a/dumpling/v4/export/sql_test.go
+++ b/dumpling/v4/export/sql_test.go
@@ -1,6 +1,7 @@
 package export
 
 import (
+	"context"
 	"errors"
 
 	"github.com/DATA-DOG/go-sqlmock"
@@ -54,6 +55,8 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	mockConf := DefaultConfig()
 	mockConf.SortByPk = true
@@ -65,14 +68,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) {
 	mock.ExpectExec("SELECT _tidb_rowid from `test`.`t`").
 		WillReturnResult(sqlmock.NewResult(0, 0))
 
-	orderByClause, err := buildOrderByClause(mockConf, db, "test", "t")
+	orderByClause, err := buildOrderByClause(mockConf, conn, "test", "t")
 	c.Assert(err, IsNil)
 
 	mock.ExpectQuery("SELECT COLUMN_NAME").
 		WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
 		WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", ""))
 
-	selectedField, err := buildSelectField(db, "test", "t")
+	selectedField, err := buildSelectField(conn, "test", "t")
 	c.Assert(err, IsNil)
 	q := buildSelectQuery("test", "t", selectedField, "", orderByClause)
 	c.Assert(q, Equals, "SELECT * FROM `test`.`t` ORDER BY _tidb_rowid")
@@ -81,14 +84,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) {
 	mock.ExpectExec("SELECT _tidb_rowid from `test`.`t`").
 		WillReturnError(errors.New(`1054, "Unknown column '_tidb_rowid' in 'field list'"`))
 
-	orderByClause, err = buildOrderByClause(mockConf, db, "test", "t")
+	orderByClause, err = buildOrderByClause(mockConf, conn, "test", "t")
 	c.Assert(err, IsNil)
 
 	mock.ExpectQuery("SELECT COLUMN_NAME").
 		WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
 		WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", ""))
 
-	selectedField, err = buildSelectField(db, "test", "t")
+	selectedField, err = buildSelectField(conn, "test", "t")
 	c.Assert(err, IsNil)
 	q = buildSelectQuery("test", "t", selectedField, "", orderByClause)
 	c.Assert(q, Equals, "SELECT * FROM `test`.`t`")
@@ -104,14 +107,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) {
 		mock.ExpectQuery("SELECT column_name FROM information_schema.columns").
 			WithArgs("test", "t").
 			WillReturnRows(sqlmock.NewRows([]string{"column_name"}).AddRow("id"))
-		orderByClause, err := buildOrderByClause(mockConf, db, "test", "t")
+		orderByClause, err := buildOrderByClause(mockConf, conn, "test", "t")
 		c.Assert(err, IsNil, cmt)
 
 		mock.ExpectQuery("SELECT COLUMN_NAME").
 			WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
 			WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", ""))
 
-		selectedField, err = buildSelectField(db, "test", "t")
+		selectedField, err = buildSelectField(conn, "test", "t")
 		c.Assert(err, IsNil)
 		q = buildSelectQuery("test", "t", selectedField, "", orderByClause)
 		c.Assert(q, Equals, "SELECT * FROM `test`.`t` ORDER BY `id`", cmt)
@@ -128,14 +131,14 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) {
 			WithArgs("test", "t").
 			WillReturnRows(sqlmock.NewRows([]string{"column_name"}))
 
-		orderByClause, err := buildOrderByClause(mockConf, db, "test", "t")
+		orderByClause, err := buildOrderByClause(mockConf, conn, "test", "t")
 		c.Assert(err, IsNil, cmt)
 
 		mock.ExpectQuery("SELECT COLUMN_NAME").
 			WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
 			WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", ""))
 
-		selectedField, err = buildSelectField(db, "test", "t")
+		selectedField, err = buildSelectField(conn, "test", "t")
 		c.Assert(err, IsNil)
 		q := buildSelectQuery("test", "t", selectedField, "", orderByClause)
 		c.Assert(q, Equals, "SELECT * FROM `test`.`t`", cmt)
@@ -154,7 +157,7 @@ func (s *testDumpSuite) TestBuildSelectAllQuery(c *C) {
 			WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
 			WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", ""))
 
-		selectedField, err := buildSelectField(db, "test", "t")
+		selectedField, err := buildSelectField(conn, "test", "t")
 		c.Assert(err, IsNil)
 		q := buildSelectQuery("test", "t", selectedField, "", "")
 		c.Assert(q, Equals, "SELECT * FROM `test`.`t`", cmt)
@@ -166,13 +169,15 @@ func (s *testDumpSuite) TestBuildSelectField(c *C) {
 	db, mock, err := sqlmock.New()
 	c.Assert(err, IsNil)
 	defer db.Close()
+	conn, err := db.Conn(context.Background())
+	c.Assert(err, IsNil)
 
 	// generate columns not found
 	mock.ExpectQuery("SELECT COLUMN_NAME").
 		WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).
 		WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("id", ""))
 
-	selectedField, err := buildSelectField(db, "test", "t")
+	selectedField, err := buildSelectField(conn, "test", "t")
 	c.Assert(selectedField, Equals, "*")
 	c.Assert(err, IsNil)
 	c.Assert(mock.ExpectationsWereMet(), IsNil)
@@ -183,7 +188,7 @@ func (s *testDumpSuite) TestBuildSelectField(c *C) {
 		WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).
 			AddRow("id", "").AddRow("name", "").AddRow("quo`te", "").AddRow("generated", "VIRTUAL GENERATED"))
 
-	selectedField, err = buildSelectField(db, "test", "t")
+	selectedField, err = buildSelectField(conn, "test", "t")
 	c.Assert(selectedField, Equals, "`id`,`name`,`quo``te`")
 	c.Assert(err, IsNil)
 	c.Assert(mock.ExpectationsWereMet(), IsNil)
diff --git a/dumpling/v4/export/test_util.go b/dumpling/v4/export/test_util.go
index 1c625541..f4fee268 100644
--- a/dumpling/v4/export/test_util.go
+++ b/dumpling/v4/export/test_util.go
@@ -1,6 +1,7 @@
 package export
 
 import (
+	"context"
 	"database/sql"
 	"database/sql/driver"
 	"fmt"
@@ -87,6 +88,10 @@ type mockTableIR struct {
 	rowErr          error
 }
 
+func (m *mockTableIR) Start(ctx context.Context, conn *sql.Conn) error {
+	return nil
+}
+
 func (m *mockTableIR) DatabaseName() string {
 	return m.dbName
 }