diff --git a/drainer/syncer.go b/drainer/syncer.go index 998eb5590..62490c1f6 100644 --- a/drainer/syncer.go +++ b/drainer/syncer.go @@ -14,6 +14,7 @@ import ( "github.com/pingcap/tidb-binlog/drainer/checkpoint" "github.com/pingcap/tidb-binlog/drainer/executor" "github.com/pingcap/tidb-binlog/drainer/translator" + "github.com/pingcap/tidb-binlog/pkg/loader" pkgsql "github.com/pingcap/tidb-binlog/pkg/sql" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/store/tikv/oracle" @@ -55,7 +56,7 @@ type Syncer struct { filter *filter - causality *causality + causality *loader.Causality lastSyncTime time.Time } @@ -70,7 +71,7 @@ func NewSyncer(ctx context.Context, cp checkpoint.CheckPoint, cfg *SyncerConfig) syncer.ctx, syncer.cancel = context.WithCancel(ctx) syncer.initCommitTS = cp.TS() syncer.positions = make(map[string]int64) - syncer.causality = newCausality() + syncer.causality = loader.NewCausality() syncer.lastSyncTime = time.Now() syncer.filter = newFilter(formatIgnoreSchemas(cfg.IgnoreSchemas), cfg.DoDBs, cfg.DoTables) @@ -243,7 +244,7 @@ func (s *Syncer) addJob(job *job) { if wait { eventCounter.WithLabelValues("savepoint").Add(1) s.jobWg.Wait() - s.causality.reset() + s.causality.Reset() s.savePoint(job.commitTS) } } @@ -270,18 +271,18 @@ func (s *Syncer) resolveCasuality(keys []string) (string, error) { return keys[0], nil } - if s.causality.detectConflict(keys) { + if s.causality.DetectConflict(keys) { if err := s.flushJobs(); err != nil { return "", errors.Trace(err) } - s.causality.reset() + s.causality.Reset() } - if err := s.causality.add(keys); err != nil { + if err := s.causality.Add(keys); err != nil { return "", errors.Trace(err) } - return s.causality.get(keys[0]), nil + return s.causality.Get(keys[0]), nil } func (s *Syncer) flushJobs() error { diff --git a/go.mod b/go.mod index 872cfe8b9..4222e42d7 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/pingcap/tidb-binlog require ( github.com/BurntSushi/toml v0.3.1 + github.com/DATA-DOG/go-sqlmock v1.3.2 github.com/Shopify/sarama v1.18.0 github.com/Shopify/toxiproxy v2.1.3+incompatible // indirect github.com/beorn7/perks v0.0.0-20160229213445-3ac7bf7a47d1 // indirect @@ -62,15 +63,15 @@ require ( github.com/petar/GoLLRB v0.0.0-20130427215148-53be0d36a84c // indirect github.com/pierrec/lz4 v2.0.5+incompatible // indirect github.com/pingcap/check v0.0.0-20171206051426-1c287c953996 - github.com/pingcap/errors v0.11.0 // indirect + github.com/pingcap/errors v0.11.0 github.com/pingcap/goleveldb v0.0.0-20161010101021-158edde5a354 // indirect github.com/pingcap/kvproto v0.0.0-20181010074705-0ba3ca8a6e37 // indirect github.com/pingcap/parser v0.0.0-20181210061630-27e9d3e251d4 // indirect github.com/pingcap/pd v2.0.5+incompatible github.com/pingcap/tidb v2.1.0-beta.0.20180823032518-ef6590e1899a+incompatible - github.com/pingcap/tidb-tools v2.1.1-0.20181130053235-0206fdab9ef8+incompatible + github.com/pingcap/tidb-tools v2.1.3-0.20190215110732-23405d82dbe6+incompatible github.com/pingcap/tipb v0.0.0-20180711115030-4141907f6909 - github.com/pkg/errors v0.8.0 // indirect + github.com/pkg/errors v0.8.0 github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v0.8.0 github.com/prometheus/client_model v0.0.0-20150212101744-fa8ad6fec335 // indirect @@ -79,7 +80,7 @@ require ( github.com/rcrowley/go-metrics v0.0.0-20180503174638-e2704e165165 github.com/samuel/go-zookeeper v0.0.0-20170815201139-e6b59f6144be github.com/siddontang/go v0.0.0-20161005110831-1e9ce2a5ac40 - github.com/sirupsen/logrus 
v0.0.0-20180830201151-78fa2915c1fa // indirect + github.com/sirupsen/logrus v0.0.0-20180830201151-78fa2915c1fa github.com/soheilhy/cmux v0.1.2 github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 // indirect github.com/stretchr/testify v1.2.2 // indirect @@ -97,7 +98,7 @@ require ( golang.org/x/crypto v0.0.0-20150218234220-1351f936d976 // indirect golang.org/x/lint v0.0.0-20181011164241-5906bd5c48cd // indirect golang.org/x/net v0.0.0-20180724234803-3673e40ba225 - golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect + golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f golang.org/x/sys v0.0.0-20161006025142-8d1157a43547 golang.org/x/time v0.0.0-20170420181420-c06e80d9300e // indirect golang.org/x/tools v0.0.0-20181012201414-c0eb142035b5 // indirect @@ -109,7 +110,7 @@ require ( gopkg.in/fsnotify.v1 v1.4.7 // indirect gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2 // indirect gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce // indirect - gopkg.in/natefinch/lumberjack.v2 v2.0.0-20170531160350-a96e63847dc3 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.0.0-20170531160350-a96e63847dc3 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.0.0-20170407172122-cd8b52f8269e // indirect ) diff --git a/pkg/loader/README.md b/pkg/loader/README.md new file mode 100644 index 000000000..8f2cb5e79 --- /dev/null +++ b/pkg/loader/README.md @@ -0,0 +1,31 @@ +loader +====== + +A package to load data into MySQL in real time, intended to be shared by *reparo*, *drainer*, etc. + + +### Getting started +- An example is available in [example_loader_test.go](./example_loader_test.go) + + To use *Loader*, you need to write a translator, like *SlaveBinlogToTxn* in [translate.go](./translate.go), that turns the upstream data format (e.g. binlog) into `Txn` objects. + + +## Overview +Loader splits the upstream transaction DML events and loads them into MySQL concurrently (sharded by primary key or unique key), while respecting causality with [causality.go](./causality.go). + + +## Optimization +#### Large Operation +Instead of executing DML statements one by one, we can combine many small operations into a single large one, e.g. an INSERT statement with multiple VALUES lists that inserts several rows at a time. This is [faster](https://medium.com/@benmorel/high-speed-inserts-with-mysql-9d3dcd76f723) than inserting one by one. + +#### Merge by Primary Key +You may want to read about Kafka's [log compaction](https://kafka.apache.org/documentation/#compaction). + +We can treat a table with a primary key like a KV store: to rebuild the table from its change history, we only need the last value of every key. + +While synchronizing data into the downstream in real time, we can get DML events from the upstream in batches and merge them by key. After merging, there is only one event per key, so the downstream executes far fewer events than the upstream produced. This also lets us use batch insert operations. + +We should also consider secondary unique keys here, see *execTableBatch* in [executor.go](./executor.go). Currently, we only merge by primary key and use batch operations if the table has a primary key and no unique key.
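+
+#### Example
+The snippet below is only a minimal sketch of the two ideas above — merge by primary key, then flush with one multi-row statement — not the package's actual implementation (that lives in [merge.go](./merge.go) and [executor.go](./executor.go)); the table `test1` and its columns are made up for illustration.
+
+```go
+package main
+
+import "fmt"
+
+func main() {
+	// change history of one table: key 1 is written twice, key 2 once
+	type event struct{ id, a1 int }
+	history := []event{{1, 10}, {2, 20}, {1, 11}}
+
+	// merge by primary key: only the last value of every key matters
+	last := make(map[int]int)
+	for _, e := range history {
+		last[e.id] = e.a1
+	}
+
+	// build one multi-row REPLACE instead of one statement per row
+	sql := "REPLACE INTO test1(id, a1) VALUES "
+	var args []interface{}
+	i := 0
+	for id, a1 := range last {
+		if i > 0 {
+			sql += ","
+		}
+		sql += "(?,?)"
+		args = append(args, id, a1)
+		i++
+	}
+	fmt.Println(sql, args) // in real code: db.Exec(sql, args...)
+}
+```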
+ + + diff --git a/pkg/loader/bench_test.go b/pkg/loader/bench_test.go new file mode 100644 index 000000000..144688cc1 --- /dev/null +++ b/pkg/loader/bench_test.go @@ -0,0 +1,239 @@ +package loader + +import ( + "database/sql" + "fmt" + "sync" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/juju/errors" + "github.com/ngaut/log" +) + +func getTestDB() (db *sql.DB, err error) { + dsn := "root:@tcp(127.0.0.1:3306)/?charset=utf8&interpolateParams=true&readTimeout=1m&multiStatements=true" + db, err = sql.Open("mysql", dsn) + return +} + +func BenchmarkInsertMerge(b *testing.B) { + benchmarkWrite(b, true) +} + +func BenchmarkInsertNoMerge(b *testing.B) { + benchmarkWrite(b, false) +} + +func BenchmarkUpdateMerge(b *testing.B) { + benchmarkUpdate(b, true) +} + +func BenchmarkUpdateNoMerge(b *testing.B) { + benchmarkUpdate(b, false) +} + +func BenchmarkDeleteMerge(b *testing.B) { + benchmarkDelete(b, true) +} + +func BenchmarkDeleteNoMerge(b *testing.B) { + benchmarkDelete(b, false) +} + +func benchmarkUpdate(b *testing.B, merge bool) { + log.SetLevelByString("error") + + r, err := newRunner(merge) + if err != nil { + b.Fatal(err) + } + + dropTable(r.db, r.loader) + createTable(r.db, r.loader) + + loadTable(r.db, r.loader, b.N) + + b.ResetTimer() + updateTable(r.db, r.loader, b.N) + + r.close() +} + +func benchmarkDelete(b *testing.B, merge bool) { + log.SetLevelByString("error") + + r, err := newRunner(merge) + if err != nil { + b.Fatal(err) + } + + dropTable(r.db, r.loader) + createTable(r.db, r.loader) + + loadTable(r.db, r.loader, b.N) + + b.ResetTimer() + deleteTable(r.db, r.loader, b.N) + + r.close() +} + +func benchmarkWrite(b *testing.B, merge bool) { + log.SetLevelByString("error") + + r, err := newRunner(merge) + if err != nil { + b.Fatal(err) + } + + dropTable(r.db, r.loader) + createTable(r.db, r.loader) + + b.ResetTimer() + loadTable(r.db, r.loader, b.N) + + r.close() +} + +type runner struct { + db *sql.DB + loader *Loader + wg sync.WaitGroup +} + +func newRunner(merge bool) (r *runner, err error) { + db, err := getTestDB() + if err != nil { + return nil, errors.Trace(err) + } + + loader, err := NewLoader(db, WorkerCount(16), BatchSize(128)) + if err != nil { + return nil, errors.Trace(err) + } + + loader.merge = merge + + r = new(runner) + r.db = db + r.loader = loader + + r.wg.Add(1) + go func() { + err := loader.Run() + if err != nil { + log.Fatal(err) + } + r.wg.Done() + }() + + go func() { + for range loader.Successes() { + + } + }() + + return +} + +func (r *runner) close() { + r.loader.Close() + r.wg.Wait() +} + +func createTable(db *sql.DB, loader *Loader) error { + var sql string + + sql = "create table test1(id int primary key, a1 int)" + // sql = "create table test1(id int, a1 int, UNIQUE KEY `id` (`id`))" + loader.Input() <- NewDDLTxn("test", "test1", sql) + + return nil +} + +func dropTable(db *sql.DB, loader *Loader) error { + sql := fmt.Sprintf("drop table if exists test1") + loader.Input() <- NewDDLTxn("test", "test1", sql) + return nil +} + +func loadTable(db *sql.DB, loader *Loader, n int) error { + var txns []*Txn + for i := 0; i < n; i++ { + txn := new(Txn) + dml := &DML{ + Database: "test", + Table: "test1", + Tp: InsertDMLType, + Values: map[string]interface{}{ + "id": i, + "a1": i, + }, + } + + txn.AppendDML(dml) + txns = append(txns, txn) + } + + for _, txn := range txns { + loader.Input() <- txn + } + + return nil +} + +func updateTable(db *sql.DB, loader *Loader, n int) error { + var txns []*Txn + for i := 0; i < n; i++ { + txn := new(Txn) + 
dml := &DML{ + Database: "test", + Table: "test1", + Tp: UpdateDMLType, + Values: map[string]interface{}{ + "id": i, + "a1": i * 10, + }, + OldValues: map[string]interface{}{ + "id": i, + "a1": i, + }, + } + + txn.AppendDML(dml) + txns = append(txns, txn) + } + + for _, txn := range txns { + loader.Input() <- txn + } + + return nil +} + +func deleteTable(db *sql.DB, loader *Loader, n int) error { + var txns []*Txn + for i := 0; i < n; i++ { + txn := new(Txn) + dml := &DML{ + Database: "test", + Table: "test1", + Tp: DeleteDMLType, + Values: map[string]interface{}{ + "id": i, + "a1": i, + }, + } + + txn.AppendDML(dml) + txns = append(txns, txn) + } + + for _, txn := range txns { + loader.Input() <- txn + } + + return nil + +} diff --git a/drainer/causality.go b/pkg/loader/causality.go similarity index 81% rename from drainer/causality.go rename to pkg/loader/causality.go index 508361ec3..957429121 100644 --- a/drainer/causality.go +++ b/pkg/loader/causality.go @@ -11,30 +11,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -package drainer +package loader import "github.com/juju/errors" -// causality provides a simple mechanism to improve the concurrency of SQLs execution under the premise of ensuring correctness. +// Causality provides a simple mechanism to improve the concurrency of SQLs execution under the premise of ensuring correctness. // causality groups sqls that maybe contain causal relationships, and syncer executes them linearly. // if some conflicts exist in more than one groups, then syncer waits all SQLs that are grouped be executed and reset causality. // this mechanism meets quiescent consistency to ensure correctness. -type causality struct { +type Causality struct { relations map[string]string } -func newCausality() *causality { - return &causality{ +func NewCausality() *Causality { + return &Causality{ relations: make(map[string]string), } } -func (c *causality) add(keys []string) error { +func (c *Causality) Add(keys []string) error { if len(keys) == 0 { return nil } - if c.detectConflict(keys) { + if c.DetectConflict(keys) { return errors.New("some conflicts in causality, must be resolved") } // find causal key @@ -54,16 +54,16 @@ func (c *causality) add(keys []string) error { return nil } -func (c *causality) get(key string) string { +func (c *Causality) Get(key string) string { return c.relations[key] } -func (c *causality) reset() { +func (c *Causality) Reset() { c.relations = make(map[string]string) } -// detectConflict detects whether there is a conflict -func (c *causality) detectConflict(keys []string) bool { +// DetectConflict detects whether there is a conflict +func (c *Causality) DetectConflict(keys []string) bool { if len(keys) == 0 { return false } diff --git a/drainer/causality_test.go b/pkg/loader/causality_test.go similarity index 73% rename from drainer/causality_test.go rename to pkg/loader/causality_test.go index d585378ad..21c7b5947 100644 --- a/drainer/causality_test.go +++ b/pkg/loader/causality_test.go @@ -11,28 +11,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -package drainer +package loader import ( . 
"github.com/pingcap/check" ) -func (s *testDrainerSuite) TestCausality(c *C) { - ca := newCausality() +type causalitySuite struct{} + +var _ = Suite(&causalitySuite{}) + +func (s *causalitySuite) TestCausality(c *C) { + ca := NewCausality() caseData := []string{"test_1", "test_2", "test_3"} excepted := map[string]string{ "test_1": "test_1", "test_2": "test_1", "test_3": "test_1", } - c.Assert(ca.add(caseData), IsNil) + c.Assert(ca.Add(caseData), IsNil) c.Assert(ca.relations, DeepEquals, excepted) - c.Assert(ca.add([]string{"test_4"}), IsNil) + c.Assert(ca.Add([]string{"test_4"}), IsNil) excepted["test_4"] = "test_4" c.Assert(ca.relations, DeepEquals, excepted) conflictData := []string{"test_4", "test_3"} - c.Assert(ca.detectConflict(conflictData), IsTrue) - c.Assert(ca.add(conflictData), NotNil) - ca.reset() + c.Assert(ca.DetectConflict(conflictData), IsTrue) + c.Assert(ca.Add(conflictData), NotNil) + ca.Reset() c.Assert(ca.relations, HasLen, 0) } diff --git a/pkg/loader/example_loader_test.go b/pkg/loader/example_loader_test.go new file mode 100644 index 000000000..61d280803 --- /dev/null +++ b/pkg/loader/example_loader_test.go @@ -0,0 +1,64 @@ +package loader + +import "log" + +func Example() { + // create sql.DB + db, err := CreateDB("root", "", "localhost", 4000) + if err != nil { + log.Fatal(err) + } + + // init loader + loader, err := NewLoader(db, WorkerCount(16), BatchSize(128)) + if err != nil { + log.Fatal(err) + } + + // get the success txn from loader + go func() { + // the return order will be the order you push into loader.Input() + for txn := range loader.Successes() { + log.Print("succ: ", txn) + } + }() + + // run loader + go func() { + // return non nil if encounter some case fail to load data the downstream + // or nil when loader is closed when all data is loaded to downstream + err := loader.Run() + if err != nil { + log.Fatal(err) + } + }() + + // push ddl txn + loader.Input() <- NewDDLTxn("test", "test", "create table test(id primary key)") + + // push one insert dml txn + values := map[string]interface{}{"id": 1} + loader.Input() <- &Txn{ + DMLs: []*DML{{Database: "test", Table: "test", Tp: InsertDMLType, Values: values}}, + } + + // push one update dml txn + newValues := map[string]interface{}{"id": 2} + loader.Input() <- &Txn{ + DMLs: []*DML{{Database: "test", Table: "test", Tp: UpdateDMLType, Values: newValues, OldValues: values}}, + } + + // you can set safe mode or not at run time + // which use replace for insert event and delete + replace for update make it be idempotent + loader.SetSafeMode(true) + + // push one delete dml txn + loader.Input() <- &Txn{ + DMLs: []*DML{{Database: "test", Table: "test", Tp: DeleteDMLType, Values: newValues}}, + } + //... + + // Close the Loader. 
No more Txn can be pushed into Input() + // Run will quit when all data is drained + loader.Close() } diff --git a/pkg/loader/executor.go b/pkg/loader/executor.go new file mode 100644 index 000000000..94c2782f8 --- /dev/null +++ b/pkg/loader/executor.go @@ -0,0 +1,322 @@ +package loader + +import ( + "context" + gosql "database/sql" + "fmt" + "strings" + "time" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/sync/errgroup" +) + +var defaultBatchSize = 128 + +type executor struct { + db *gosql.DB + batchSize int + queryHistogramVec *prometheus.HistogramVec +} + +func newExecutor(db *gosql.DB) *executor { + exe := &executor{ + db: db, + batchSize: defaultBatchSize, + } + + return exe +} + +func (e *executor) withBatchSize(batchSize int) *executor { + e.batchSize = batchSize + return e +} + +func (e *executor) withQueryHistogramVec(queryHistogramVec *prometheus.HistogramVec) *executor { + e.queryHistogramVec = queryHistogramVec + return e +} + +func groupByTable(dmls []*DML) (tables map[string][]*DML) { + if len(dmls) == 0 { + return nil + } + + tables = make(map[string][]*DML) + for _, dml := range dmls { + table := quoteSchema(dml.Database, dml.Table) + tableDMLs := tables[table] + tableDMLs = append(tableDMLs, dml) + tables[table] = tableDMLs + } + + return +} + +func (e *executor) execTableBatchRetry(dmls []*DML, retryNum int, backoff time.Duration) error { + var err error + for i := 0; i < retryNum; i++ { + if i > 0 { + time.Sleep(backoff) + } + + err = e.execTableBatch(dmls) + if err == nil { + return nil + } + } + return errors.Trace(err) +} + +// tx is a wrapper of *sql.Tx with metrics +type tx struct { + *gosql.Tx + queryHistogramVec *prometheus.HistogramVec +} + +// exec is a wrapper of sql.Tx.Exec() that records the query duration +func (tx *tx) exec(query string, args ...interface{}) (gosql.Result, error) { + start := time.Now() + res, err := tx.Tx.Exec(query, args...) + if tx.queryHistogramVec != nil { + tx.queryHistogramVec.WithLabelValues("exec").Observe(time.Since(start).Seconds()) + } + + return res, err + } + +// commit is a wrapper of sql.Tx.Commit() that records the commit duration +func (tx *tx) commit() error { + start := time.Now() + err := tx.Tx.Commit() + if tx.queryHistogramVec != nil { + tx.queryHistogramVec.WithLabelValues("commit").Observe(time.Since(start).Seconds()) + } + + return errors.Trace(err) +} + +// begin returns a wrapped sql.Tx +func (e *executor) begin() (*tx, error) { + sqlTx, err := e.db.Begin() + if err != nil { + return nil, errors.Trace(err) + } + + return &tx{ + Tx: sqlTx, + queryHistogramVec: e.queryHistogramVec, + }, nil +} + +func (e *executor) bulkDelete(deletes []*DML) error { + var sqls strings.Builder + var argss []interface{} + + for _, dml := range deletes { + sql, args := dml.sql() + sqls.WriteString(sql) + sqls.WriteByte(';') + argss = append(argss, args...) + } + tx, err := e.begin() + if err != nil { + return errors.Trace(err) + } + sql := sqls.String() + _, err = tx.exec(sql, argss...)
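+ // note: the ';'-joined deletes above are sent in a single round trip, which relies on the multiStatements=true option in the DSN (see CreateDB in util.go)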
+ if err != nil { + log.Errorf("exec fail sql: %s, args: %v", sql, argss) + tx.Rollback() + return errors.Trace(err) + } + + err = tx.commit() + if err != nil { + return errors.Trace(err) + } + return nil +} + +func (e *executor) bulkReplace(inserts []*DML) error { + if len(inserts) == 0 { + return nil + } + + info := inserts[0].info + dbName := inserts[0].Database + tableName := inserts[0].Table + + builder := new(strings.Builder) + + fmt.Fprintf(builder, "REPLACE INTO %s(%s) VALUES ", quoteSchema(dbName, tableName), buildColumnList(info.columns)) + + holder := fmt.Sprintf("(%s)", holderString(len(info.columns))) + for i := 0; i < len(inserts); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.WriteString(holder) + } + + var args []interface{} + for _, insert := range inserts { + for _, name := range info.columns { + v := insert.Values[name] + args = append(args, v) + } + } + + tx, err := e.begin() + if err != nil { + return errors.Trace(err) + } + _, err = tx.exec(builder.String(), args...) + if err != nil { + log.Errorf("exec fail sql: %s, args: %v, err: %v", builder.String(), args, err) + tx.Rollback() + return errors.Trace(err) + } + err = tx.commit() + if err != nil { + return errors.Trace(err) + } + return nil +} + +// We merge dmls by primary key; after the merge there is only one dml per +// primary key, holding its newest value (like a KV store). +// To avoid duplicate-entry errors on other columns, the delete dmls are applied first, then the inserts & updates. +// REPLACE is used to handle updates that touch a unique index (see https://github.com/pingcap/tidb-binlog/pull/437/files); +// alternatively, we could check whether an update modifies a unique index column, and turn such an update into (delete + insert). +// The final result must contain no duplicate entry, otherwise the original dmls were wrong.
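+// For example, for a single key, INSERT(v1) followed by UPDATE(v1 -> v2) merges into INSERT(v2), +// and UPDATE followed by DELETE merges into just the DELETE; see the rule table in merge.go.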
+func (e *executor) execTableBatch(dmls []*DML) error { + if len(dmls) == 0 { + return nil + } + + types, err := mergeByPrimaryKey(dmls) + if err != nil { + return errors.Trace(err) + } + + log.Debugf("dmls: %v after merge: %v", dmls, types) + + if allDeletes, ok := types[DeleteDMLType]; ok { + err := e.splitExecDML(allDeletes, e.bulkDelete) + if err != nil { + return errors.Trace(err) + } + } + + if allInserts, ok := types[InsertDMLType]; ok { + err := e.splitExecDML(allInserts, e.bulkReplace) + if err != nil { + return errors.Trace(err) + } + } + + if allUpdates, ok := types[UpdateDMLType]; ok { + err := e.splitExecDML(allUpdates, e.bulkReplace) + if err != nil { + return errors.Trace(err) + } + } + + return nil +} + +// splitExecDML split dmls to size of e.batchSize and call exec concurrently +func (e *executor) splitExecDML(dmls []*DML, exec func(dmls []*DML) error) error { + errg, _ := errgroup.WithContext(context.Background()) + + for _, split := range splitDMLs(dmls, e.batchSize) { + split := split + errg.Go(func() error { + err := exec(split) + if err != nil { + return errors.Trace(err) + } + return nil + }) + } + + return errors.Trace(errg.Wait()) +} + +func (e *executor) singleExecRetry(allDMLs []*DML, safeMode bool, retryNum int, backoff time.Duration) error { + var err error + + for _, dmls := range splitDMLs(allDMLs, e.batchSize) { + var i int + for i = 0; i < retryNum; i++ { + if i > 0 { + time.Sleep(backoff) + } + + err = e.singleExec(dmls, safeMode) + if err == nil { + break + } + } + if err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func (e *executor) singleExec(dmls []*DML, safeMode bool) error { + tx, err := e.begin() + if err != nil { + return errors.Trace(err) + } + + for _, dml := range dmls { + if safeMode && dml.Tp == UpdateDMLType { + sql, args := dml.deleteSQL() + log.Debugf("exec: %s, args: %v", sql, args) + _, err := tx.exec(sql, args...) + if err != nil { + log.Errorf("err: %v, exec dml sql: %s, args: %v", err, sql, args) + tx.Rollback() + return errors.Trace(err) + } + + sql, args = dml.replaceSQL() + log.Debugf("exec: %s, args: %v", sql, args) + _, err = tx.exec(sql, args...) + if err != nil { + log.Errorf("err: %v, exec dml sql: %s, args: %v", err, sql, args) + tx.Rollback() + return errors.Trace(err) + } + } else if safeMode && dml.Tp == InsertDMLType { + sql, args := dml.replaceSQL() + log.Debugf("exec dml sql: %s, args: %v", sql, args) + _, err := tx.exec(sql, args...) + if err != nil { + log.Errorf("err: %v, exec dml sql: %s, args: %v", err, sql, args) + tx.Rollback() + return errors.Trace(err) + } + } else { + sql, args := dml.sql() + log.Debugf("exec dml sql: %s, args: %v", sql, args) + _, err := tx.exec(sql, args...) 
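+ // on error, the transaction is rolled back below; singleExecRetry will retry the whole batch with backoff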
+ if err != nil { + log.Errorf("err: %v, exec dml sql: %s, args: %v", err, sql, args) + tx.Rollback() + return errors.Trace(err) + } + } + } + + err = tx.commit() + return errors.Trace(err) +} diff --git a/pkg/loader/load.go b/pkg/loader/load.go new file mode 100644 index 000000000..58d7e2d5e --- /dev/null +++ b/pkg/loader/load.go @@ -0,0 +1,468 @@ +package loader + +import ( + "context" + gosql "database/sql" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/sync/errgroup" + + pkgsql "github.com/pingcap/tidb-binlog/pkg/sql" +) + +const ( + maxDMLRetryCount = 100 + maxDDLRetryCount = 5 + + execLimitMultiple = 3 +) + +// Loader is used to load data to mysql +type Loader struct { + // we can get table info from downstream db + // like column name, pk & uk + db *gosql.DB + + tableInfos sync.Map + + batchSize int + workerCount int + + input chan *Txn + successTxn chan *Txn + + metrics *MetricsGroup + + // change update -> delete + replace + // insert -> replace + safeMode int32 + + // always true now + // merge the same primary key DML sequence, then batch insert + merge bool +} + +// MetricsGroup contains metrics of Loader +type MetricsGroup struct { + EventCounterVec *prometheus.CounterVec + QueryHistogramVec *prometheus.HistogramVec +} + +type options struct { + workerCount int + batchSize int + metrics *MetricsGroup +} + +var defaultLoaderOptions = options{ + workerCount: 16, + batchSize: 20, + metrics: nil, +} + +// A LoaderOption sets options such batch size, worker count etc. +type LoaderOption func(*options) + +// WorkerCount set worker count of loader +func WorkerCount(n int) LoaderOption { + return func(o *options) { + o.workerCount = n + } +} + +// BatchSize set batch size of loader +func BatchSize(n int) LoaderOption { + return func(o *options) { + o.batchSize = n + } +} + +// Metrics set metrics of loader +func Metrics(m *MetricsGroup) LoaderOption { + return func(o *options) { + o.metrics = m + } +} + +// NewLoader return a Loader +// db must support multi statement and interpolateParams +func NewLoader(db *gosql.DB, opt ...LoaderOption) (*Loader, error) { + opts := defaultLoaderOptions + for _, o := range opt { + o(&opts) + } + + s := &Loader{ + db: db, + workerCount: opts.workerCount, + batchSize: opts.batchSize, + metrics: opts.metrics, + input: make(chan *Txn, 1024), + successTxn: make(chan *Txn, 1024), + merge: true, + } + + db.SetMaxOpenConns(opts.workerCount) + + return s, nil +} + +func (s *Loader) metricsInputTxn(txn *Txn) { + if s.metrics == nil { + return + } + + s.metrics.EventCounterVec.WithLabelValues("Txn").Inc() + + if txn.isDDL() { + s.metrics.EventCounterVec.WithLabelValues("DDL").Add(1) + } else { + var insertEvent float64 + var deleteEvent float64 + var updateEvent float64 + for _, dml := range txn.DMLs { + switch dml.Tp { + case InsertDMLType: + insertEvent++ + case UpdateDMLType: + updateEvent++ + case DeleteDMLType: + deleteEvent++ + } + } + s.metrics.EventCounterVec.WithLabelValues("Insert").Add(insertEvent) + s.metrics.EventCounterVec.WithLabelValues("Update").Add(updateEvent) + s.metrics.EventCounterVec.WithLabelValues("Delete").Add(deleteEvent) + } +} + +// SetSafeMode set safe mode +func (s *Loader) SetSafeMode(safe bool) { + if safe { + atomic.StoreInt32(&s.safeMode, 1) + } else { + atomic.StoreInt32(&s.safeMode, 0) + } +} + +// GetSafeMode get safe mode +func (s *Loader) GetSafeMode() bool { + v := atomic.LoadInt32(&s.safeMode) + + return 
v != 0 +} + +func (s *Loader) markSuccess(txns ...*Txn) { + for _, txn := range txns { + s.successTxn <- txn + } + log.Debugf("markSuccess %d txns", len(txns)) +} + +// Input returns the input channel used to put Txn into the Loader +func (s *Loader) Input() chan<- *Txn { + return s.input +} + +// Successes returns a channel from which to get the Txns successfully loaded into MySQL +func (s *Loader) Successes() <-chan *Txn { + return s.successTxn +} + +// Close closes the Loader; no more Txn can be pushed into Input() +// Run will quit when all data is drained +func (s *Loader) Close() { + close(s.input) +} + +func (s *Loader) refreshTableInfo(schema string, table string) (info *tableInfo, err error) { + info, err = getTableInfo(s.db, schema, table) + if err != nil { + return info, errors.Trace(err) + } + + if len(info.uniqueKeys) == 0 { + log.Warnf("table %s has no primary key or unique index; syncing it to the downstream may be slow, we highly recommend adding a primary key or unique key to the table", quoteSchema(schema, table)) + } + + s.tableInfos.Store(quoteSchema(schema, table), info) + + return +} + +func (s *Loader) getTableInfo(schema string, table string) (info *tableInfo, err error) { + v, ok := s.tableInfos.Load(quoteSchema(schema, table)) + if ok { + info = v.(*tableInfo) + return + } + + return s.refreshTableInfo(schema, table) +} + +func (s *Loader) execDDL(ddl *DDL) error { + log.Debug("exec ddl: ", ddl) + var err error + var tx *gosql.Tx + for i := 0; i < maxDDLRetryCount; i++ { + if i > 0 { + time.Sleep(time.Second) + } + + tx, err = s.db.Begin() + if err != nil { + log.Error(err) + continue + } + + if len(ddl.Database) > 0 { + _, err = tx.Exec(fmt.Sprintf("use %s;", quoteName(ddl.Database))) + if err != nil { + log.Error(err) + tx.Rollback() + continue + } + } + + log.Infof("retry num: %d, exec ddl: %s", i, ddl.SQL) + _, err = tx.Exec(ddl.SQL) + if err != nil { + log.Error(err) + tx.Rollback() + continue + } + + err = tx.Commit() + if err != nil { + log.Error(err) + continue + } + + log.Info("exec ddl success: ", ddl.SQL) + return nil + } + + return errors.Trace(err) +} + +func (s *Loader) execByHash(executor *executor, byHash [][]*DML) error { + errg, _ := errgroup.WithContext(context.Background()) + + for _, dmls := range byHash { + if len(dmls) == 0 { + continue + } + + dmls := dmls + + errg.Go(func() error { + err := executor.singleExecRetry(dmls, s.GetSafeMode(), maxDMLRetryCount, time.Second) + return err + }) + } + + err := errg.Wait() + + return errors.Trace(err) +} + +func (s *Loader) singleExec(executor *executor, dmls []*DML) error { + causality := NewCausality() + + var byHash = make([][]*DML, s.workerCount) + + for _, dml := range dmls { + keys := getKeys(dml) + log.Debugf("dml: %v keys: %v", dml, keys) + conflict := causality.DetectConflict(keys) + if conflict { + log.Infof("causality conflict detected, executing pending dmls now; table: %v, keys: %v", + quoteSchema(dml.Database, dml.Table), keys) + err := s.execByHash(executor, byHash) + if err != nil { + return errors.Trace(err) + } + + causality.Reset() + for i := 0; i < len(byHash); i++ { + byHash[i] = byHash[i][:0] + } + } + + causality.Add(keys) + key := causality.Get(keys[0]) + idx := int(genHashKey(key)) % len(byHash) + byHash[idx] = append(byHash[idx], dml) + + } + + err := s.execByHash(executor, byHash) + return errors.Trace(err) +} + +func (s *Loader) execDMLs(dmls []*DML) error { + if len(dmls) == 0 { + return nil + } + + for _, dml := range dmls { + var err error + dml.info, err = s.getTableInfo(dml.Database, dml.Table) + if
err != nil { + return errors.Trace(err) + } + } + + tables := groupByTable(dmls) + + batchTables := make(map[string][]*DML) + var singleDMLs []*DML + + for tableName, tableDMLs := range tables { + info := tableDMLs[0].info + if info.primaryKey != nil && len(info.uniqueKeys) == 0 && s.merge { + batchTables[tableName] = tableDMLs + } else { + singleDMLs = append(singleDMLs, tableDMLs...) + } + } + + log.Debugf("exec by tables: %d tables, by single: %d dmls", len(batchTables), len(singleDMLs)) + + errg, _ := errgroup.WithContext(context.Background()) + executor := newExecutor(s.db).withBatchSize(s.batchSize) + if s.metrics != nil { + executor = executor.withQueryHistogramVec(s.metrics.QueryHistogramVec) + } + + for _, dmls := range batchTables { + // https://golang.org/doc/faq#closures_and_goroutines + dmls := dmls + errg.Go(func() error { + err := executor.execTableBatchRetry(dmls, maxDMLRetryCount, time.Second) + return err + }) + } + + errg.Go(func() error { + err := s.singleExec(executor, singleDMLs) + return errors.Trace(err) + }) + + err := errg.Wait() + + return errors.Trace(err) +} + +// Run quits when it meets any error, or when all the txns are drained +func (s *Loader) Run() error { + defer func() { + log.Info("Run() in Loader quit") + close(s.successTxn) + }() + + var err error + + // the txns and their corresponding dmls, accumulated to execute later + var txns []*Txn + var dmls []*DML + + execDML := func() error { + err := s.execDMLs(dmls) + if err != nil { + return errors.Trace(err) + } + + s.markSuccess(txns...) + txns = txns[:0] + dmls = dmls[:0] + return nil + } + + execDDL := func(txn *Txn) error { + err := s.execDDL(txn.DDL) + if err != nil { + if !pkgsql.IgnoreDDLError(err) { + log.Errorf("exec ddl: %s fail: %v", txn.DDL.SQL, err) + return errors.Trace(err) + } + log.Warnf("ignore ddl error: %v, ddl: %v", err, txn.DDL) + } + + s.markSuccess(txn) + s.refreshTableInfo(txn.DDL.Database, txn.DDL.Table) + return nil + } + + handleTxn := func(txn *Txn) error { + s.metricsInputTxn(txn) + + // we always execute the accumulated dmls when we meet a ddl, + // and execute ddls one by one. + if txn.isDDL() { + if err = execDML(); err != nil { + return errors.Trace(err) + } + + err = execDDL(txn) + if err != nil { + return errors.Trace(err) + } + } else { + dmls = append(dmls, txn.DMLs...)
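+ // also keep the txn itself, so execDML can acknowledge it via Successes() once its dmls are flushed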
+ txns = append(txns, txn) + + // reached the limit size, execute the accumulated dmls now + if len(dmls) >= s.batchSize*s.workerCount*execLimitMultiple { + if err = execDML(); err != nil { + return errors.Trace(err) + } + } + } + + return nil + } + + for { + select { + case txn, ok := <-s.input: + if !ok { + log.Info("loader closed, quit running") + if err = execDML(); err != nil { + return errors.Trace(err) + } + return nil + } + + if err = handleTxn(txn); err != nil { + return errors.Trace(err) + } + + default: + // execute dmls ASAP if no more txn is immediately available + if len(dmls) > 0 { + if err = execDML(); err != nil { + return errors.Trace(err) + } + + continue + } + + // block waiting for the next txn + txn, ok := <-s.input + if !ok { + return nil + } + + if err = handleTxn(txn); err != nil { + return errors.Trace(err) + } + } + } +} diff --git a/pkg/loader/load_test.go b/pkg/loader/load_test.go new file mode 100644 index 000000000..d6f0564fc --- /dev/null +++ b/pkg/loader/load_test.go @@ -0,0 +1,27 @@ +package loader + +import ( + sqlmock "github.com/DATA-DOG/go-sqlmock" + check "github.com/pingcap/check" +) + +type LoadSuite struct { +} + +var _ = check.Suite(&LoadSuite{}) + +func (cs *LoadSuite) SetUpTest(c *check.C) { +} + +func (cs *LoadSuite) TearDownTest(c *check.C) { +} + +func (cs *LoadSuite) TestNewClose(c *check.C) { + db, _, err := sqlmock.New() + c.Assert(err, check.IsNil) + + loader, err := NewLoader(db) + c.Assert(err, check.IsNil) + + loader.Close() +} diff --git a/pkg/loader/merge.go b/pkg/loader/merge.go new file mode 100644 index 000000000..6a8c946a2 --- /dev/null +++ b/pkg/loader/merge.go @@ -0,0 +1,115 @@ +package loader + +import ( + "github.com/ngaut/log" + "github.com/pkg/errors" +) + +// all DMLs must belong to the same table; +// we merge consecutive DMLs by primary key, so that after the merge there is +// only one record per key: +// insert + delete -> delete +// insert + update -> insert +// insert + insert -> insert invalid +// delete + delete -> delete invalid +// delete + update -> - invalid +// delete + insert -> insert +// update + delete -> delete +// update + update -> update +// update + insert -> - invalid +func mergeByPrimaryKey(dmls []*DML) (types map[DMLType][]*DML, err error) { + if len(dmls) == 0 { + return + } + + pks := dmls[0].primaryKeys() + if len(pks) == 0 { + return nil, errors.Errorf("%s.%s no pk", dmls[0].Database, dmls[0].Table) + } + + var res = make(map[string]*DML) + + // if an update changes the primary key, replace the update with delete(old one) + insert(new one) + var tmpDmls []*DML + for _, dml := range dmls { + if dml.Tp == UpdateDMLType && dml.updateKey() { + deleteDML := &DML{ + Database: dml.Database, + Table: dml.Table, + Tp: DeleteDMLType, + Values: dml.OldValues, + info: dml.info, + } + tmpDmls = append(tmpDmls, deleteDML) + + insertDML := &DML{ + Database: dml.Database, + Table: dml.Table, + Tp: InsertDMLType, + Values: dml.Values, + OldValues: nil, + info: dml.info, + } + tmpDmls = append(tmpDmls, insertDML) + } else { + tmpDML := &DML{ + Database: dml.Database, + Table: dml.Table, + Tp: dml.Tp, + Values: dml.Values, + OldValues: dml.OldValues, + info: dml.info, + } + + tmpDmls = append(tmpDmls, tmpDML) + } + } + dmls = tmpDmls + + for _, dml := range dmls { + key := dml.formatKey() + oldDML, ok := res[key] + if !ok { + res[key] = dml + continue + } + + switch dml.Tp { + case InsertDMLType: + // ignore the previous delete + if oldDML.Tp == DeleteDMLType { + } else if oldDML.Tp == UpdateDMLType || oldDML.Tp == InsertDMLType { + log.Warnf("update-insert/insert-insert happen.
before: %+v, after: %+v", oldDML, dml) + } + res[key] = dml + case DeleteDMLType: + // insert/update + delete -> delete + res[key] = dml + case UpdateDMLType: + if oldDML.Tp == InsertDMLType { + // insert-update -> insert + dml.Tp = InsertDMLType + dml.OldValues = nil + } else if oldDML.Tp == UpdateDMLType { + // update-update -> update + dml.OldValues = oldDML.OldValues + } else if oldDML.Tp == DeleteDMLType { + // delete + update -> invalid + log.Warn("abnormal case delete + update, just remain update now") + } + res[key] = dml + + default: + return nil, errors.Errorf("unknown tp: %v", dml.Tp) + } + } + + types = make(map[DMLType][]*DML) + for _, dml := range res { + dmls = types[dml.Tp] + dmls = append(dmls, dml) + types[dml.Tp] = dmls + } + + return +} diff --git a/pkg/loader/merge_test.go b/pkg/loader/merge_test.go new file mode 100644 index 000000000..06d6e98bf --- /dev/null +++ b/pkg/loader/merge_test.go @@ -0,0 +1,145 @@ +package loader + +import ( + "math/rand" + + "github.com/ngaut/log" + check "github.com/pingcap/check" +) + +type modelSuite struct { +} + +var _ = check.Suite(&modelSuite{}) + +func (m *modelSuite) TestMerge(c *check.C) { + log.SetLevelByString("error") + info := &tableInfo{ + columns: []string{"k", "v"}, + uniqueKeys: []indexInfo{{"PRIMARY", []string{"k"}}}, + } + info.primaryKey = &info.uniqueKeys[0] + + apply := func(kv map[int]int, dmls []*DML) map[int]int { + for _, dml := range dmls { + switch dml.Tp { + case InsertDMLType: + k := dml.Values["k"].(int) + v := dml.Values["v"].(int) + kv[k] = v + case UpdateDMLType: + k := dml.Values["k"].(int) + v := dml.Values["v"].(int) + + oldk := dml.OldValues["k"].(int) + // oldv := dml.OldValues["v"].(int) + delete(kv, oldk) + kv[k] = v + case DeleteDMLType: + k := dml.Values["k"].(int) + delete(kv, k) + } + } + + return kv + } + + // generate dmlNum DML date + var dmls []*DML + dmlNum := 100000 + maxKey := 1000 + updateKeyProbability := 0.1 + + var kv = make(map[int]int) + for i := 0; i < dmlNum; i++ { + dml := new(DML) + dml.info = info + dmls = append(dmls, dml) + + k := rand.Intn(maxKey) + v, ok := kv[k] + if !ok { + // insert + dml.Tp = InsertDMLType + dml.Values = make(map[string]interface{}) + dml.Values["k"] = k + dml.Values["v"] = rand.Int() + } else { + if rand.Int()%2 == 0 { + // update + dml.Tp = UpdateDMLType + dml.OldValues = make(map[string]interface{}) + dml.OldValues["k"] = k + dml.OldValues["v"] = v + + newv := rand.Int() + dml.Values = make(map[string]interface{}) + dml.Values["k"] = k + dml.Values["v"] = newv + // check whether to update k + if rand.Float64() < updateKeyProbability { + for try := 0; try < 10; try++ { + newk := rand.Intn(maxKey) + if _, ok := kv[newk]; !ok { + dml.Values["k"] = newk + break + } + } + } + } else { + // delete + dml.Tp = DeleteDMLType + dml.Values = make(map[string]interface{}) + dml.Values["k"] = k + dml.Values["v"] = v + } + } + + kv = apply(kv, []*DML{dml}) + } + + kv = make(map[int]int) + kvMerge := make(map[int]int) + + step := dmlNum / 10 + for i := 0; i < len(dmls); i += step { + end := i + step + if end > len(dmls) { + end = len(dmls) + } + logDMLs(dmls[i:end], c) + kv = apply(kv, dmls[i:end]) + + res, err := mergeByPrimaryKey(dmls[i:end]) + c.Assert(err, check.IsNil) + + noMergeNumber := end - i + mergeNumber := 0 + if mdmls, ok := res[DeleteDMLType]; ok { + logDMLs(mdmls, c) + kvMerge = apply(kvMerge, mdmls) + mergeNumber += len(mdmls) + } + if mdmls, ok := res[InsertDMLType]; ok { + logDMLs(mdmls, c) + kvMerge = apply(kvMerge, mdmls) + c.Logf("kvMerge: %v", 
kvMerge) + mergeNumber += len(mdmls) + } + if mdmls, ok := res[UpdateDMLType]; ok { + logDMLs(mdmls, c) + kvMerge = apply(kvMerge, mdmls) + mergeNumber += len(mdmls) + } + c.Logf("before number: %d, after merge: %d", noMergeNumber, mergeNumber) + c.Logf("kv: %v kvMerge: %v", kv, kvMerge) + c.Assert(kvMerge, check.DeepEquals, kv) + } +} + +func logDMLs(dmls []*DML, c *check.C) { + c.Log("dmls: ", len(dmls)) + for _, dml := range dmls { + c.Logf("tp: %v, values: %v, OldValues: %v", dml.Tp, dml.Values, dml.OldValues) + } +} diff --git a/pkg/loader/model.go b/pkg/loader/model.go new file mode 100644 index 000000000..a23053a19 --- /dev/null +++ b/pkg/loader/model.go @@ -0,0 +1,324 @@ +package loader + +import ( + "fmt" + "strconv" + "strings" + + "github.com/ngaut/log" +) + +// DMLType represents the dml type +type DMLType int + +// DMLType types +const ( + UnknownDMLType DMLType = 0 + InsertDMLType DMLType = 1 + UpdateDMLType DMLType = 2 + DeleteDMLType DMLType = 3 +) + +// DML holds the dml info +type DML struct { + Database string + Table string + + Tp DMLType + // only set when Tp = UpdateDMLType + OldValues map[string]interface{} + Values map[string]interface{} + + info *tableInfo +} + +// DDL holds the ddl info +type DDL struct { + Database string + Table string + SQL string +} + +// Txn holds transaction info, an DDL or DML sequences +type Txn struct { + DMLs []*DML + DDL *DDL + + // This field is used to hold arbitrary data you wish to include so it + // will be available when receiving on the Successes channel + Metadata interface{} +} + +// AppendDML append a dml +func (t *Txn) AppendDML(dml *DML) { + t.DMLs = append(t.DMLs, dml) +} + +// NewDDLTxn return a Txn +func NewDDLTxn(db string, table string, sql string) *Txn { + txn := new(Txn) + txn.DDL = &DDL{ + Database: db, + Table: table, + SQL: sql, + } + + return txn +} + +func (t *Txn) String() string { + if t.isDDL() { + return fmt.Sprintf("{ddl: %s}", t.DDL.SQL) + } + + return fmt.Sprintf("dml: %v", t.DMLs) +} + +func (t *Txn) isDDL() bool { + return t.DDL != nil +} + +func (dml *DML) primaryKeys() []string { + if dml.info.primaryKey == nil { + return nil + } + + return dml.info.primaryKey.columns +} + +func (dml *DML) primaryKeyValues() []interface{} { + names := dml.primaryKeys() + + var values []interface{} + for _, name := range names { + v := dml.Values[name] + values = append(values, v) + } + + return values +} + +func (dml *DML) formatKey() string { + return formatKey(dml.primaryKeyValues()) +} + +func (dml *DML) formatOldKey() string { + return formatKey(dml.oldPrimaryKeyValues()) +} + +func (dml *DML) updateKey() bool { + if len(dml.OldValues) == 0 { + return false + } + + values := dml.primaryKeyValues() + oldValues := dml.oldPrimaryKeyValues() + + for i := 0; i < len(values); i++ { + if values[i] != oldValues[i] { + return true + } + } + + return false +} + +func (dml *DML) String() string { + return fmt.Sprintf("{db: %s, table: %s,tp: %v values: %d old_values: %d}", + dml.Database, dml.Table, dml.Tp, len(dml.Values), len(dml.OldValues)) +} + +func (dml *DML) oldPrimaryKeyValues() []interface{} { + if len(dml.OldValues) == 0 { + return dml.primaryKeyValues() + } + + names := dml.primaryKeys() + + var values []interface{} + for _, name := range names { + v := dml.OldValues[name] + values = append(values, v) + } + + return values +} + +func (dml *DML) updateSQL() (sql string, args []interface{}) { + builder := new(strings.Builder) + + fmt.Fprintf(builder, "UPDATE %s SET ", quoteSchema(dml.Database, dml.Table)) + + for 
name, arg := range dml.Values { + if len(args) > 0 { + builder.WriteByte(',') + } + fmt.Fprintf(builder, "%s = ?", quoteName(name)) + args = append(args, arg) + } + + builder.WriteString(" WHERE ") + + whereArgs := dml.buildWhere(builder) + args = append(args, whereArgs...) + + builder.WriteString(" LIMIT 1") + sql = builder.String() + return +} + +func (dml *DML) buildWhere(builder *strings.Builder) (args []interface{}) { + wnames, wargs := dml.whereSlice() + for i := 0; i < len(wnames); i++ { + if i > 0 { + builder.WriteString(" AND ") + } + if wargs[i] == nil { + builder.WriteString(quoteName(wnames[i]) + " IS NULL") + } else { + builder.WriteString(quoteName(wnames[i]) + " = ? ") + args = append(args, wargs[i]) + } + } + return +} + +func (dml *DML) whereValues(names []string) (values []interface{}) { + valueMap := dml.Values + if dml.Tp == UpdateDMLType { + valueMap = dml.OldValues + } + + for _, name := range names { + v := valueMap[name] + values = append(values, v) + } + return +} + +func (dml *DML) whereSlice() (colNames []string, args []interface{}) { + for _, index := range dml.info.uniqueKeys { + values := dml.whereValues(index.columns) + var i int + for i = 0; i < len(values); i++ { + if values[i] == nil { + break + } + } + if i == len(values) { + return index.columns, values + } + } + + return dml.info.columns, dml.whereValues(dml.info.columns) + +} + +func (dml *DML) deleteSQL() (sql string, args []interface{}) { + builder := new(strings.Builder) + + fmt.Fprintf(builder, "DELETE FROM %s WHERE ", quoteSchema(dml.Database, dml.Table)) + args = dml.buildWhere(builder) + builder.WriteString(" LIMIT 1") + + sql = builder.String() + return +} + +func (dml *DML) replaceSQL() (sql string, args []interface{}) { + info := dml.info + sql = fmt.Sprintf("REPLACE INTO %s(%s) VALUES(%s)", quoteSchema(dml.Database, dml.Table), buildColumnList(info.columns), holderString(len(info.columns))) + for _, name := range info.columns { + v := dml.Values[name] + args = append(args, v) + } + return +} + +func (dml *DML) insertSQL() (sql string, args []interface{}) { + sql, args = dml.replaceSQL() + sql = strings.Replace(sql, "REPLACE", "INSERT", 1) + return +} + +func (dml *DML) sql() (sql string, args []interface{}) { + switch dml.Tp { + case InsertDMLType: + return dml.insertSQL() + case UpdateDMLType: + return dml.updateSQL() + case DeleteDMLType: + return dml.deleteSQL() + } + + log.Debugf("dml: %+v sql: %s, args: %v", dml, sql, args) + + return +} + +func formatKey(values []interface{}) string { + builder := new(strings.Builder) + for i, v := range values { + if i != 0 { + builder.WriteString("--") + } + fmt.Fprintf(builder, "%v", v) + } + + return builder.String() +} + +func getKey(names []string, values map[string]interface{}) string { + builder := new(strings.Builder) + for _, name := range names { + v := values[name] + if v == nil { + continue + } + + fmt.Fprintf(builder, "(%s: %v)", name, v) + } + + return builder.String() +} + +func getKeys(dml *DML) (keys []string) { + info := dml.info + + tableName := quoteSchema(dml.Database, dml.Table) + + var addOldKey int + var addNewKey int + + for _, index := range info.uniqueKeys { + key := getKey(index.columns, dml.Values) + if len(key) > 0 { + addNewKey++ + keys = append(keys, key+tableName) + } + } + + if dml.Tp == UpdateDMLType { + for _, index := range info.uniqueKeys { + key := getKey(index.columns, dml.OldValues) + if len(key) > 0 { + addOldKey++ + keys = append(keys, key+tableName) + } + } + } + + if addNewKey == 0 { + key := 
getKey(info.columns, dml.Values) + tableName + key = strconv.Itoa(int(genHashKey(key))) + keys = append(keys, key) + } + + if dml.Tp == UpdateDMLType && addOldKey == 0 { + key := getKey(info.columns, dml.OldValues) + tableName + key = strconv.Itoa(int(genHashKey(key))) + keys = append(keys, key) + } + + return +} diff --git a/pkg/loader/model_test.go b/pkg/loader/model_test.go new file mode 100644 index 000000000..b908a3471 --- /dev/null +++ b/pkg/loader/model_test.go @@ -0,0 +1,83 @@ +package loader + +import ( + "strings" + + check "github.com/pingcap/check" +) + +type dmlSuite struct { +} + +var _ = check.Suite(&dmlSuite{}) + +func getDML(key bool, tp DMLType) *DML { + info := &tableInfo{ + columns: []string{"id", "a1"}, + } + + if key { + info.uniqueKeys = append(info.uniqueKeys, indexInfo{"PRIMARY", []string{"id"}}) + } + + dml := new(DML) + dml.info = info + dml.Database = "test" + dml.Table = "test" + dml.Tp = tp + + return dml +} + +func (d *dmlSuite) TestWhere(c *check.C) { + d.testWhere(c, InsertDMLType) + d.testWhere(c, UpdateDMLType) + d.testWhere(c, DeleteDMLType) +} + +func (d *dmlSuite) testWhere(c *check.C, tp DMLType) { + dml := getDML(true, tp) + var values = map[string]interface{}{ + "id": 1, + "a1": 1, + } + + if tp == UpdateDMLType { + dml.OldValues = values + } else { + dml.Values = values + } + + names, args := dml.whereSlice() + c.Assert(names, check.DeepEquals, []string{"id"}) + c.Assert(args, check.DeepEquals, []interface{}{1}) + + builder := new(strings.Builder) + args = dml.buildWhere(builder) + c.Assert(args, check.DeepEquals, []interface{}{1}) + c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) + + // no pk + dml = getDML(false, tp) + if tp == UpdateDMLType { + dml.OldValues = values + } else { + dml.Values = values + } + + names, args = dml.whereSlice() + c.Assert(names, check.DeepEquals, []string{"id", "a1"}) + c.Assert(args, check.DeepEquals, []interface{}{1, 1}) + + builder.Reset() + args = dml.buildWhere(builder) + c.Assert(args, check.DeepEquals, []interface{}{1, 1}) + c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) + + // set a1 to NULL value + values["a1"] = nil + builder.Reset() + args = dml.buildWhere(builder) + c.Assert(args, check.DeepEquals, []interface{}{1}) + c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) +} diff --git a/pkg/loader/translate.go b/pkg/loader/translate.go new file mode 100644 index 000000000..b2714db45 --- /dev/null +++ b/pkg/loader/translate.go @@ -0,0 +1,85 @@ +package loader + +import ( + pb "github.com/pingcap/tidb-tools/tidb-binlog/slave_binlog_proto/go-binlog" +) + +// SlaveBinlogToTxn translate the Binlog format into Txn +func SlaveBinlogToTxn(binlog *pb.Binlog) (txn *Txn) { + txn = new(Txn) + switch binlog.Type { + case pb.BinlogType_DDL: + data := binlog.DdlData + txn.DDL = new(DDL) + txn.DDL.Database = data.GetSchemaName() + txn.DDL.Table = data.GetTableName() + txn.DDL.SQL = string(data.GetDdlQuery()) + case pb.BinlogType_DML: + for _, table := range binlog.DmlData.GetTables() { + for _, mut := range table.GetMutations() { + dml := new(DML) + txn.DMLs = append(txn.DMLs, dml) + dml.Database = table.GetSchemaName() + dml.Table = table.GetTableName() + switch mut.GetType() { + case pb.MutationType_Insert: + dml.Tp = InsertDMLType + case pb.MutationType_Update: + dml.Tp = UpdateDMLType + case pb.MutationType_Delete: + dml.Tp = DeleteDMLType + } + + // setup values + dml.Values = make(map[string]interface{}) + for i, col := range mut.Row.GetColumns() { + 
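// Row's columns are assumed to be positionally aligned with table.ColumnInfo, so index i pairs each value with its column +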
name := table.ColumnInfo[i].Name + arg := columnToArg(table.ColumnInfo[i].GetMysqlType(), col) + dml.Values[name] = arg + } + + // setup old values + if dml.Tp == UpdateDMLType { + dml.OldValues = make(map[string]interface{}) + for i, col := range mut.ChangeRow.GetColumns() { + name := table.ColumnInfo[i].Name + arg := columnToArg(table.ColumnInfo[i].GetMysqlType(), col) + dml.OldValues[name] = arg + } + } + } + } + } + return +} + +func columnToArg(mysqlType string, c *pb.Column) (arg interface{}) { + if c.GetIsNull() { + return nil + } + + if c.Int64Value != nil { + return c.GetInt64Value() + } + + if c.Uint64Value != nil { + return c.GetUint64Value() + } + + if c.DoubleValue != nil { + return c.GetDoubleValue() + } + + if c.BytesValue != nil { + // https://github.com/go-sql-driver/mysql/issues/819 + // for downstream = mysql + // it work for tidb to use binary + if mysqlType == "json" { + var str string = string(c.GetBytesValue()) + return str + } + return c.GetBytesValue() + } + + return c.GetStringValue() +} diff --git a/pkg/loader/util.go b/pkg/loader/util.go new file mode 100644 index 000000000..cd16bf8c6 --- /dev/null +++ b/pkg/loader/util.go @@ -0,0 +1,210 @@ +package loader + +import ( + gosql "database/sql" + "fmt" + "hash/crc32" + "strings" + + "github.com/juju/errors" +) + +type tableInfo struct { + columns []string + primaryKey *indexInfo + // include primary key if have + uniqueKeys []indexInfo +} + +type indexInfo struct { + name string + columns []string +} + +// getTableInfo return the table info +// https://dev.mysql.com/doc/refman/8.0/en/show-columns.html +// https://dev.mysql.com/doc/refman/8.0/en/show-index.html +func getTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, err error) { + info = new(tableInfo) + + // get column info + // + // mysql> SHOW COLUMNS FROM City; + // +-------------+----------+------+-----+---------+----------------+ + // | Field | Type | Null | Key | Default | Extra | + // +-------------+----------+------+-----+---------+----------------+ + // | ID | int(11) | NO | PRI | NULL | auto_increment | + // | Name | char(35) | NO | | | | + // | CountryCode | char(3) | NO | MUL | | | + // | District | char(20) | NO | | | | + // | Population | int(11) | NO | | 0 | | + // +-------------+----------+------+-----+---------+----------------+ + sql := fmt.Sprintf("show columns from %s", quoteSchema(schema, table)) + rows, err := db.Query(sql) + if err != nil { + return nil, errors.Trace(err) + } + + defer rows.Close() + + for rows.Next() { + cols := make([]interface{}, 6) + var name string + cols[0] = &name + for i := 1; i < len(cols); i++ { + cols[i] = &gosql.RawBytes{} + } + + err = rows.Scan(cols...) 
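+ // only the first column (Field, the column name) is kept; the rest are scanned into RawBytes and discarded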
diff --git a/pkg/loader/util.go b/pkg/loader/util.go
new file mode 100644
index 000000000..cd16bf8c6
--- /dev/null
+++ b/pkg/loader/util.go
@@ -0,0 +1,210 @@
+package loader
+
+import (
+	gosql "database/sql"
+	"fmt"
+	"hash/crc32"
+	"strings"
+
+	"github.com/juju/errors"
+)
+
+type tableInfo struct {
+	columns    []string
+	primaryKey *indexInfo
+	// uniqueKeys contains the primary key too, if the table has one
+	uniqueKeys []indexInfo
+}
+
+type indexInfo struct {
+	name    string
+	columns []string
+}
+
+// getTableInfo returns the column and index information of the given table
+// https://dev.mysql.com/doc/refman/8.0/en/show-columns.html
+// https://dev.mysql.com/doc/refman/8.0/en/show-index.html
+func getTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, err error) {
+	info = new(tableInfo)
+
+	// get column info
+	//
+	// mysql> SHOW COLUMNS FROM City;
+	// +-------------+----------+------+-----+---------+----------------+
+	// | Field       | Type     | Null | Key | Default | Extra          |
+	// +-------------+----------+------+-----+---------+----------------+
+	// | ID          | int(11)  | NO   | PRI | NULL    | auto_increment |
+	// | Name        | char(35) | NO   |     |         |                |
+	// | CountryCode | char(3)  | NO   | MUL |         |                |
+	// | District    | char(20) | NO   |     |         |                |
+	// | Population  | int(11)  | NO   |     | 0       |                |
+	// +-------------+----------+------+-----+---------+----------------+
+	sql := fmt.Sprintf("show columns from %s", quoteSchema(schema, table))
+	rows, err := db.Query(sql)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	defer rows.Close()
+
+	for rows.Next() {
+		// only the Field column is needed; the other five SHOW COLUMNS
+		// fields are scanned into throwaway RawBytes
+		cols := make([]interface{}, 6)
+		var name string
+		cols[0] = &name
+		for i := 1; i < len(cols); i++ {
+			cols[i] = &gosql.RawBytes{}
+		}
+
+		err = rows.Scan(cols...)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		info.columns = append(info.columns, name)
+	}
+
+	if err = rows.Err(); err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// get index info
+	//
+	// mysql> show index from a;
+	// +-------+------------+----------+--------------+-------------+-----------+-------------+----------+--------+------+------------+---------+---------------+
+	// | Table | Non_unique | Key_name | Seq_in_index | Column_name | Collation | Cardinality | Sub_part | Packed | Null | Index_type | Comment | Index_comment |
+	// +-------+------------+----------+--------------+-------------+-----------+-------------+----------+--------+------+------------+---------+---------------+
+	// | a     | 0          | PRIMARY  | 1            | id          | A         | 0           | NULL     | NULL   |      | BTREE      |         |               |
+	// | a     | 1          | a1       | 1            | a1          | A         | 0           | NULL     | NULL   | YES  | BTREE      |         |               |
+	// +-------+------------+----------+--------------+-------------+-----------+-------------+----------+--------+------+------------+---------+---------------+
+	sql = fmt.Sprintf("show index from %s", quoteSchema(schema, table))
+	rows, err = db.Query(sql)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	defer rows.Close()
+
+	// collect the primary key and unique keys; Key_name is "PRIMARY"
+	// for the primary key and the index name otherwise
+	for rows.Next() {
+		cols := make([]interface{}, 13)
+		for i := 0; i < len(cols); i++ {
+			cols[i] = &gosql.RawBytes{}
+		}
+
+		var nonUnique int
+		var keyName string
+		var columnName string
+		var seqInIndex int // starts at 1
+		cols[1] = &nonUnique
+		cols[2] = &keyName
+		cols[3] = &seqInIndex
+		cols[4] = &columnName
+
+		err = rows.Scan(cols...)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// log.Debug(nonUnique, keyName, columnName)
+		if nonUnique == 1 {
+			continue
+		}
+
+		var i int
+		// place the columns in Seq_in_index order
+		for i = 0; i < len(info.uniqueKeys); i++ {
+			if info.uniqueKeys[i].name == keyName {
+				// grow the columns slice as needed
+				for seqInIndex > len(info.uniqueKeys[i].columns) {
+					info.uniqueKeys[i].columns = append(info.uniqueKeys[i].columns, "")
+				}
+				info.uniqueKeys[i].columns[seqInIndex-1] = columnName
+				break
+			}
+		}
+		if i == len(info.uniqueKeys) {
+			info.uniqueKeys = append(info.uniqueKeys, indexInfo{keyName, []string{columnName}})
+		}
+	}
+
+	// move the primary key to the first place and set primaryKey
+	for i := 0; i < len(info.uniqueKeys); i++ {
+		if info.uniqueKeys[i].name == "PRIMARY" {
+			info.uniqueKeys[i], info.uniqueKeys[0] = info.uniqueKeys[0], info.uniqueKeys[i]
+			info.primaryKey = &info.uniqueKeys[0]
+			break
+		}
+	}
+
+	if err = rows.Err(); err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return
+}
+
+// CreateDB returns a *sql.DB connected to the given MySQL instance.
+func CreateDB(user string, password string, host string, port int) (db *gosql.DB, err error) {
+	// interpolateParams and multiStatements allow batched, single-round-trip writes
+	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, host, port)
+
+	db, err = gosql.Open("mysql", dsn)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	return
+}
+
+func quoteSchema(schema string, table string) string {
+	return fmt.Sprintf("`%s`.`%s`", escapeName(schema), escapeName(table))
+}
+
+func quoteName(name string) string {
+	return "`" + escapeName(name) + "`"
+}
+
+func escapeName(name string) string {
+	return strings.Replace(name, "`", "``", -1)
+}
+
+func holderString(n int) string {
+	builder := new(strings.Builder)
+	for i := 0; i < n; i++ {
+		if i > 0 {
+			builder.WriteString(",")
+		}
+		builder.WriteString("?")
+	}
+	return builder.String()
+}
+
+func genHashKey(key string) uint32 {
+	return crc32.ChecksumIEEE([]byte(key))
+}
+
+func splitDMLs(dmls []*DML, size int) (res [][]*DML) {
+	for i := 0; i < len(dmls); i += size {
+		end := i + size
+		if end > len(dmls) {
+			end = len(dmls)
+		}
+
+		res = append(res, dmls[i:end])
+	}
+	return
+}
+
+func buildColumnList(names []string) string {
+	b := new(strings.Builder)
+	for i, name := range names {
+		if i > 0 {
+			b.WriteString(",")
+		}
+		b.WriteString(quoteName(name))
+	}
+
+	return b.String()
+}
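quoteSchema, buildColumnList and holderString are the building blocks for the multi-value statements described in the README's "Large Operation" section. A rough sketch of how they could compose into a batched INSERT (hypothetical helper, not part of this diff):

```go
package loader

import (
	"fmt"
	"strings"
)

// buildMultiInsert sketches how the helpers above combine into one
// multi-row statement such as:
//   INSERT INTO `db`.`tbl` (`a`,`b`) VALUES (?,?),(?,?),...
// The ? placeholders are then bound to len(columns)*rowCount arguments.
func buildMultiInsert(schema, table string, columns []string, rowCount int) string {
	row := "(" + holderString(len(columns)) + ")"
	rows := make([]string, rowCount)
	for i := range rows {
		rows[i] = row
	}
	return fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
		quoteSchema(schema, table), buildColumnList(columns), strings.Join(rows, ","))
}
```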
+ AddRow("test1", 1, "dex3", 1, "a4", "", "", "", "", "", "", "", "") + + mock.ExpectQuery("show columns").WillReturnRows(columnRows) + + mock.ExpectQuery("show index").WillReturnRows(indexRows) + + info, err := getTableInfo(db, "test", "test1") + c.Assert(err, check.IsNil) + c.Assert(info, check.NotNil) + + c.Assert(info, check.DeepEquals, &tableInfo{ + columns: []string{"id", "a1", "a2", "a3", "a4"}, + primaryKey: &indexInfo{"PRIMARY", []string{"id"}}, + uniqueKeys: []indexInfo{{"PRIMARY", []string{"id"}}, + {"dex1", []string{"a1"}}, + {"dex2", []string{"a2", "a3"}}, + }}) +} diff --git a/tests/dailytest/case.go b/tests/dailytest/case.go index 821ff860d..355155a29 100644 --- a/tests/dailytest/case.go +++ b/tests/dailytest/case.go @@ -2,8 +2,12 @@ package dailytest import ( "database/sql" + "fmt" + "math/rand" "strings" + "time" + "github.com/juju/errors" "github.com/ngaut/log" "github.com/pingcap/tidb-binlog/diff" ) @@ -144,6 +148,18 @@ func RunCase(cfg *diff.Config, src *sql.DB, dst *sql.DB) { } }) + // random op on have both pk and uk table + RunTest(cfg, src, dst, func(src *sql.DB) { + start := time.Now() + + err := updatePKUK(src, 1000) + if err != nil { + log.Fatal(errors.ErrorStack(err)) + } + + log.Info(" updatePKUK take: ", time.Since(start)) + }) + // clean table RunTest(cfg, src, dst, func(src *sql.DB) { err := execSQLs(src, case3Clean) @@ -152,6 +168,49 @@ func RunCase(cfg *diff.Config, src *sql.DB, dst *sql.DB) { } }) + // swap unique index value + RunTest(cfg, src, dst, func(src *sql.DB) { + _, err := src.Exec("create table uindex(id int primary key, a1 int unique)") + if err != nil { + log.Fatal(err) + } + + _, err = src.Exec("insert into uindex(id, a1) values(1, 10),(2,20)") + if err != nil { + log.Fatal(err) + } + + tx, err := src.Begin() + if err != nil { + log.Fatal(err) + } + + _, err = tx.Exec("update uindex set a1 = 30 where id = 1") + if err != nil { + log.Fatal(err) + } + + _, err = tx.Exec("update uindex set a1 = 10 where id = 2") + if err != nil { + log.Fatal(err) + } + + _, err = tx.Exec("update uindex set a1 = 20 where id = 1") + if err != nil { + log.Fatal(err) + } + + err = tx.Commit() + if err != nil { + log.Fatal(err) + } + + _, err = src.Exec("drop table uindex") + if err != nil { + log.Fatal(err) + } + }) + // test big binlog msg RunTest(cfg, src, dst, func(src *sql.DB) { _, err := src.Query("create table binlog_big(id int primary key, data longtext);") @@ -186,3 +245,102 @@ func RunCase(cfg *diff.Config, src *sql.DB, dst *sql.DB) { }) } + +// updatePKUK create a table with primary key and unique key +// then do opNum randomly DML +func updatePKUK(db *sql.DB, opNum int) error { + maxKey := 20 + _, err := db.Exec("create table pkuk(pk int primary key, uk int, v int, unique key uk(uk));") + if err != nil { + return errors.Trace(err) + } + + var pks []int + addPK := func(pk int) { + pks = append(pks, pk) + } + removePK := func(pk int) { + var tmp []int + for _, v := range pks { + if v != pk { + tmp = append(tmp, v) + } + } + pks = tmp + } + hasPK := func(pk int) bool { + for _, v := range pks { + if v == pk { + return true + } + } + return false + } + + for i := 0; i < opNum; { + var sql string + pk := rand.Intn(maxKey) + uk := rand.Intn(maxKey) + v := rand.Intn(10000) + oldPK := rand.Intn(maxKey) + + // try randomly insert&update&delete + op := rand.Intn(3) + switch op { + case 0: + if len(pks) == maxKey { + continue + } + for hasPK(pk) { + log.Info(pks) + pk = rand.Intn(maxKey) + } + sql = fmt.Sprintf("insert into pkuk(pk, uk, v) values(%d,%d,%d)", pk, uk, 
v) + case 1: + if len(pks) == 0 { + continue + } + for !hasPK(oldPK) { + log.Info(pks) + oldPK = rand.Intn(maxKey) + } + sql = fmt.Sprintf("update pkuk set pk = %d, uk = %d, v = %d where pk = %d", pk, uk, v, oldPK) + case 2: + if len(pks) == 0 { + continue + } + for !hasPK(pk) { + log.Info(pks) + pk = rand.Intn(maxKey) + } + sql = fmt.Sprintf("delete from pkuk where pk = %d", pk) + } + + _, err := db.Exec(sql) + if err != nil { + // for insert and update, we didn't check for uk's duplicate + if strings.Contains(err.Error(), "Duplicate entry") { + continue + } + return errors.Trace(err) + } + + switch op { + case 0: + addPK(pk) + case 1: + removePK(oldPK) + addPK(pk) + case 2: + removePK(pk) + } + i++ + } + + _, err = db.Exec("drop table pkuk") + if err != nil { + return errors.Trace(err) + } + + return nil +} diff --git a/tests/kafka/kafka.go b/tests/kafka/kafka.go index a05fd2940..62580a869 100644 --- a/tests/kafka/kafka.go +++ b/tests/kafka/kafka.go @@ -2,7 +2,6 @@ package main import ( "flag" - "fmt" "strings" "time" @@ -11,17 +10,13 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" "github.com/pingcap/tidb-binlog/diff" + "github.com/pingcap/tidb-binlog/pkg/loader" "github.com/pingcap/tidb-binlog/tests/dailytest" "github.com/pingcap/tidb-binlog/tests/util" "github.com/pingcap/tidb-tools/tidb-binlog/driver/reader" - pb "github.com/pingcap/tidb-tools/tidb-binlog/slave_binlog_proto/go-binlog" - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/parser" ) // drainer -> kafka, syn data from kafka to downstream TiDB, and run the dailytest -// most copy from github.com/pingcap/tidb-tools/tidb-binlog/driver/example/mysql/mysql.go -// TODO maybe later we can replace by the `new tool` package var ( kafkaAddr = flag.String("kafkaAddr", "127.0.0.1:9092", "kafkaAddr like 127.0.0.1:9092,127.0.0.1:9093") @@ -32,6 +27,7 @@ var ( func main() { flag.Parse() + log.Debug("start run kafka test...") cfg := &reader.Config{ KafkaAddr: strings.Split(*kafkaAddr, ","), @@ -56,7 +52,21 @@ func main() { } // start sync to mysql from kafka + ld, err := loader.NewLoader(sinkDB, loader.WorkerCount(16), loader.BatchSize(128)) + if err != nil { + panic(err) + } + + go func() { + err := ld.Run() + if err != nil { + log.Error(errors.ErrorStack(err)) + log.Fatal(err) + } + }() + go func() { + defer ld.Close() for { select { case msg := <-breader.Messages(): @@ -66,25 +76,9 @@ func main() { } log.Debug("recv: ", str) binlog := msg.Binlog - sqls, args := toSQL(binlog) - - tx, err := sinkDB.Begin() - if err != nil { - log.Fatal(err) - } - - for i := 0; i < len(sqls); i++ { - // log.Debug("exec: args: ", sqls[i], args[i]) - _, err = tx.Exec(sqls[i], args[i]...) 
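The rewritten consumer above reduces to a small feed-and-drain loop around the loader. Condensed into one hypothetical helper (all names are the pkg/loader API introduced by this PR; error handling collapsed into log.Fatal):

```go
package main

import (
	"database/sql"

	"github.com/ngaut/log"
	"github.com/pingcap/tidb-binlog/pkg/loader"
	pb "github.com/pingcap/tidb-tools/tidb-binlog/slave_binlog_proto/go-binlog"
)

// syncBinlogs runs for the life of the process: it translates each slave
// binlog into a Txn for the loader, and drains Successes in the same
// select so the loader is never blocked on its output side.
func syncBinlogs(db *sql.DB, binlogs <-chan *pb.Binlog) {
	ld, err := loader.NewLoader(db, loader.WorkerCount(16), loader.BatchSize(128))
	if err != nil {
		log.Fatal(err)
	}

	// one goroutine executes the batched DML against db
	go func() {
		if err := ld.Run(); err != nil {
			log.Fatal(err)
		}
	}()

	for {
		select {
		case binlog := <-binlogs:
			ld.Input() <- loader.SlaveBinlogToTxn(binlog)
		case txn := <-ld.Successes():
			log.Debug("succ: ", txn)
		}
	}
}
```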
@@ -100,204 +94,3 @@ func main() {
 	}
 	dailytest.Run(sourceDB, sinkDB, diffCfg, 10, 1000, 10)
 }
-
-func columnToArg(c *pb.Column) (arg interface{}) {
-	if c.GetIsNull() {
-		return nil
-	}
-
-	if c.Int64Value != nil {
-		return c.GetInt64Value()
-	}
-
-	if c.Uint64Value != nil {
-		return c.GetUint64Value()
-	}
-
-	if c.DoubleValue != nil {
-		return c.GetDoubleValue()
-	}
-
-	if c.BytesValue != nil {
-		return c.GetBytesValue()
-	}
-
-	return c.GetStringValue()
-}
-
-func tableToSQL(table *pb.Table) (sqls []string, sqlArgs [][]interface{}) {
-	replace := func(row *pb.Row) {
-		sql := fmt.Sprintf("replace into `%s`.`%s`", table.GetSchemaName(), table.GetTableName())
-
-		var names []string
-		var placeHolders []string
-		for _, c := range table.GetColumnInfo() {
-			names = append(names, c.GetName())
-			placeHolders = append(placeHolders, "?")
-		}
-		sql += "(" + strings.Join(names, ",") + ")"
-		sql += "values(" + strings.Join(placeHolders, ",") + ")"
-
-		var args []interface{}
-		for _, col := range row.GetColumns() {
-			args = append(args, columnToArg(col))
-		}
-
-		sqls = append(sqls, sql)
-		sqlArgs = append(sqlArgs, args)
-	}
-
-	constructWhere := func(args []interface{}) (sql string, usePK bool) {
-		var whereColumns []string
-		var whereArgs []interface{}
-		for i, col := range table.GetColumnInfo() {
-			if col.GetIsPrimaryKey() {
-				whereColumns = append(whereColumns, col.GetName())
-				whereArgs = append(whereArgs, args[i])
-				usePK = true
-			}
-		}
-		// no primary key
-		if len(whereColumns) == 0 {
-			for i, col := range table.GetColumnInfo() {
-				whereColumns = append(whereColumns, col.GetName())
-				whereArgs = append(whereArgs, args[i])
-			}
-		}
-
-		sql = " where "
-		for i, col := range whereColumns {
-			if i != 0 {
-				sql += " and "
-			}
-
-			if whereArgs[i] == nil {
-				sql += fmt.Sprintf("%s IS NULL ", col)
-			} else {
-				sql += fmt.Sprintf("%s = ? ", col)
-			}
-		}
-
-		sql += " limit 1"
-
-		return
-	}
-
-	for _, mutation := range table.Mutations {
-		switch mutation.GetType() {
-		case pb.MutationType_Insert:
-			replace(mutation.Row)
-		case pb.MutationType_Update:
-			columnInfo := table.GetColumnInfo()
-			sql := fmt.Sprintf("update `%s`.`%s` set ", table.GetSchemaName(), table.GetTableName())
-			// construct c1 = ?, c2 = ?...
-			for i, col := range columnInfo {
-				if i != 0 {
-					sql += ","
-				}
-				sql += fmt.Sprintf("%s = ? ", col.Name)
-			}
-
-			row := mutation.Row
-			changedRow := mutation.ChangeRow
-
-			var args []interface{}
-			// for set
-			for _, col := range row.GetColumns() {
-				args = append(args, columnToArg(col))
-			}
-
-			where, usePK := constructWhere(args)
-			sql += where
-
-			// for where
-			for i, col := range changedRow.GetColumns() {
-				if columnToArg(col) == nil {
-					continue
-				}
-				if !usePK || columnInfo[i].GetIsPrimaryKey() {
-					args = append(args, columnToArg(col))
-				}
-			}
-
-			sqls = append(sqls, sql)
-			sqlArgs = append(sqlArgs, args)
-
-		case pb.MutationType_Delete:
-			columnInfo := table.GetColumnInfo()
-			row := mutation.Row
-
-			var values []interface{}
-			for _, col := range row.GetColumns() {
-				values = append(values, columnToArg(col))
-			}
-			where, usePK := constructWhere(values)
-
-			sql := fmt.Sprintf("delete from `%s`.`%s` %s", table.GetSchemaName(), table.GetTableName(), where)
-
-			var args []interface{}
-			for i, col := range row.GetColumns() {
-				if columnToArg(col) == nil {
-					continue
-				}
-				if !usePK || columnInfo[i].GetIsPrimaryKey() {
-					args = append(args, columnToArg(col))
-				}
-			}
-
-			sqls = append(sqls, sql)
-			sqlArgs = append(sqlArgs, args)
-		}
-	}
-
-	return
-}
-
-func isCreateDatabase(sql string) (isCreateDatabase bool, err error) {
-	if !strings.Contains(strings.ToLower(sql), "database") {
-		return false, nil
-	}
-
-	stmt, err := parser.New().ParseOneStmt(sql, "", "")
-	if err != nil {
-		return false, errors.Annotate(err, fmt.Sprintf("parse: %s", sql))
-	}
-
-	_, isCreateDatabase = stmt.(*ast.CreateDatabaseStmt)
-
-	return
-}
-
-func toSQL(binlog *pb.Binlog) ([]string, [][]interface{}) {
-	var allSQL []string
-	var allArgs [][]interface{}
-
-	switch binlog.GetType() {
-	case pb.BinlogType_DDL:
-		ddl := binlog.DdlData
-		isCreateDatabase, err := isCreateDatabase(string(ddl.DdlQuery))
-		if err != nil {
-			log.Fatal(errors.ErrorStack(err))
-		}
-		if !isCreateDatabase {
-			sql := fmt.Sprintf("use %s", ddl.GetSchemaName())
-			allSQL = append(allSQL, sql)
-			allArgs = append(allArgs, nil)
-		}
-		allSQL = append(allSQL, string(ddl.DdlQuery))
-		allArgs = append(allArgs, nil)
-
-	case pb.BinlogType_DML:
-		dml := binlog.DmlData
-		for _, table := range dml.GetTables() {
-			sqls, sqlArgs := tableToSQL(table)
-			allSQL = append(allSQL, sqls...)
-			allArgs = append(allArgs, sqlArgs...)
-		}
-
-	default:
-		log.Fatal("unknown type: ", binlog.GetType())
-	}
-
-	return allSQL, allArgs
-}
diff --git a/tests/util/db.go b/tests/util/db.go
index 3affa5031..707626325 100644
--- a/tests/util/db.go
+++ b/tests/util/db.go
@@ -38,7 +38,7 @@ func CreateDB(cfg DBConfig) (*sql.DB, error) {
 	zone, offset := time.Now().Zone()
 	zone = fmt.Sprintf("'+%02d:00'", offset/3600)
 
-	dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&time_zone=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, url.QueryEscape(zone))
+	dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&interpolateParams=true&multiStatements=true&time_zone=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, url.QueryEscape(zone))
 	db, err := sql.Open("mysql", dbDSN)
 	if err != nil {
 		return nil, errors.Trace(err)
diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/LICENSE b/vendor/github.com/DATA-DOG/go-sqlmock/LICENSE
new file mode 100644
index 000000000..6ee063ce7
--- /dev/null
+++ b/vendor/github.com/DATA-DOG/go-sqlmock/LICENSE
@@ -0,0 +1,28 @@
+The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
+
+Copyright (c) 2013-2019, DATA-DOG team
+All rights reserved.
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name DataDog.lt may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/argument.go b/vendor/github.com/DATA-DOG/go-sqlmock/argument.go new file mode 100644 index 000000000..7727481a8 --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/argument.go @@ -0,0 +1,24 @@ +package sqlmock + +import "database/sql/driver" + +// Argument interface allows to match +// any argument in specific way when used with +// ExpectedQuery and ExpectedExec expectations. +type Argument interface { + Match(driver.Value) bool +} + +// AnyArg will return an Argument which can +// match any kind of arguments. +// +// Useful for time.Time or similar kinds of arguments. +func AnyArg() Argument { + return anyArgument{} +} + +type anyArgument struct{} + +func (a anyArgument) Match(_ driver.Value) bool { + return true +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/driver.go b/vendor/github.com/DATA-DOG/go-sqlmock/driver.go new file mode 100644 index 000000000..802f8fbe7 --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/driver.go @@ -0,0 +1,81 @@ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "sync" +) + +var pool *mockDriver + +func init() { + pool = &mockDriver{ + conns: make(map[string]*sqlmock), + } + sql.Register("sqlmock", pool) +} + +type mockDriver struct { + sync.Mutex + counter int + conns map[string]*sqlmock +} + +func (d *mockDriver) Open(dsn string) (driver.Conn, error) { + d.Lock() + defer d.Unlock() + + c, ok := d.conns[dsn] + if !ok { + return c, fmt.Errorf("expected a connection to be available, but it is not") + } + + c.opened++ + return c, nil +} + +// New creates sqlmock database connection and a mock to manage expectations. +// Accepts options, like ValueConverterOption, to use a ValueConverter from +// a specific driver. +// Pings db so that all expectations could be +// asserted. 
+func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { + pool.Lock() + dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) + pool.counter++ + + smock := &sqlmock{dsn: dsn, drv: pool, ordered: true} + pool.conns[dsn] = smock + pool.Unlock() + + return smock.open(options) +} + +// NewWithDSN creates sqlmock database connection with a specific DSN +// and a mock to manage expectations. +// Accepts options, like ValueConverterOption, to use a ValueConverter from +// a specific driver. +// Pings db so that all expectations could be asserted. +// +// This method is introduced because of sql abstraction +// libraries, which do not provide a way to initialize +// with sql.DB instance. For example GORM library. +// +// Note, it will error if attempted to create with an +// already used dsn +// +// It is not recommended to use this method, unless you +// really need it and there is no other way around. +func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { + pool.Lock() + if _, ok := pool.conns[dsn]; ok { + pool.Unlock() + return nil, nil, fmt.Errorf("cannot create a new mock database with the same dsn: %s", dsn) + } + smock := &sqlmock{dsn: dsn, drv: pool, ordered: true} + pool.conns[dsn] = smock + pool.Unlock() + + return smock.open(options) +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/expectations.go b/vendor/github.com/DATA-DOG/go-sqlmock/expectations.go new file mode 100644 index 000000000..27a716e7d --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/expectations.go @@ -0,0 +1,355 @@ +package sqlmock + +import ( + "database/sql/driver" + "fmt" + "strings" + "sync" + "time" +) + +// an expectation interface +type expectation interface { + fulfilled() bool + Lock() + Unlock() + String() string +} + +// common expectation struct +// satisfies the expectation interface +type commonExpectation struct { + sync.Mutex + triggered bool + err error +} + +func (e *commonExpectation) fulfilled() bool { + return e.triggered +} + +// ExpectedClose is used to manage *sql.DB.Close expectation +// returned by *Sqlmock.ExpectClose. +type ExpectedClose struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.DB.Close action +func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedClose) String() string { + msg := "ExpectedClose => expecting database Close" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// ExpectedBegin is used to manage *sql.DB.Begin expectation +// returned by *Sqlmock.ExpectBegin. +type ExpectedBegin struct { + commonExpectation + delay time.Duration +} + +// WillReturnError allows to set an error for *sql.DB.Begin action +func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedBegin) String() string { + msg := "ExpectedBegin => expecting database transaction Begin" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin { + e.delay = duration + return e +} + +// ExpectedCommit is used to manage *sql.Tx.Commit expectation +// returned by *Sqlmock.ExpectCommit. 
+type ExpectedCommit struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.Tx.Close action +func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedCommit) String() string { + msg := "ExpectedCommit => expecting transaction Commit" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// ExpectedRollback is used to manage *sql.Tx.Rollback expectation +// returned by *Sqlmock.ExpectRollback. +type ExpectedRollback struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.Tx.Rollback action +func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedRollback) String() string { + msg := "ExpectedRollback => expecting transaction Rollback" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} + +// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query, +// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations. +// Returned by *Sqlmock.ExpectQuery. +type ExpectedQuery struct { + queryBasedExpectation + rows driver.Rows + delay time.Duration + rowsMustBeClosed bool + rowsWereClosed bool +} + +// WithArgs will match given expected args to actual database query arguments. +// if at least one argument does not match, it will return an error. For specific +// arguments an sqlmock.Argument interface can be used to match an argument. +func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery { + e.args = args + return e +} + +// RowsWillBeClosed expects this query rows to be closed. +func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery { + e.rowsMustBeClosed = true + return e +} + +// WillReturnError allows to set an error for expected database query +func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { + e.err = err + return e +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { + e.delay = duration + return e +} + +// String returns string representation +func (e *ExpectedQuery) String() string { + msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:" + msg += "\n - matches sql: '" + e.expectSQL + "'" + + if len(e.args) == 0 { + msg += "\n - is without arguments" + } else { + msg += "\n - is with arguments:\n" + for i, arg := range e.args { + msg += fmt.Sprintf(" %d - %+v\n", i, arg) + } + msg = strings.TrimSpace(msg) + } + + if e.rows != nil { + msg += fmt.Sprintf("\n - %s", e.rows) + } + + if e.err != nil { + msg += fmt.Sprintf("\n - should return error: %s", e.err) + } + + return msg +} + +// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations. +// Returned by *Sqlmock.ExpectExec. +type ExpectedExec struct { + queryBasedExpectation + result driver.Result + delay time.Duration +} + +// WithArgs will match given expected args to actual database exec operation arguments. +// if at least one argument does not match, it will return an error. For specific +// arguments an sqlmock.Argument interface can be used to match an argument. 
+func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec { + e.args = args + return e +} + +// WillReturnError allows to set an error for expected database exec action +func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { + e.err = err + return e +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { + e.delay = duration + return e +} + +// String returns string representation +func (e *ExpectedExec) String() string { + msg := "ExpectedExec => expecting Exec or ExecContext which:" + msg += "\n - matches sql: '" + e.expectSQL + "'" + + if len(e.args) == 0 { + msg += "\n - is without arguments" + } else { + msg += "\n - is with arguments:\n" + var margs []string + for i, arg := range e.args { + margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg)) + } + msg += strings.Join(margs, "\n") + } + + if e.result != nil { + res, _ := e.result.(*result) + msg += "\n - should return Result having:" + msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID) + msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected) + if res.err != nil { + msg += fmt.Sprintf("\n Error: %s", res.err) + } + } + + if e.err != nil { + msg += fmt.Sprintf("\n - should return error: %s", e.err) + } + + return msg +} + +// WillReturnResult arranges for an expected Exec() to return a particular +// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method +// to build a corresponding result. Or if actions needs to be tested against errors +// sqlmock.NewErrorResult(err error) to return a given error. +func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { + e.result = result + return e +} + +// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations. +// Returned by *Sqlmock.ExpectPrepare. +type ExpectedPrepare struct { + commonExpectation + mock *sqlmock + expectSQL string + statement driver.Stmt + closeErr error + mustBeClosed bool + wasClosed bool + delay time.Duration +} + +// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. +func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { + e.err = err + return e +} + +// WillReturnCloseError allows to set an error for this prepared statement Close action +func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { + e.closeErr = err + return e +} + +// WillDelayFor allows to specify duration for which it will delay +// result. May be used together with Context +func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare { + e.delay = duration + return e +} + +// WillBeClosed expects this prepared statement to +// be closed. +func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { + e.mustBeClosed = true + return e +} + +// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. +// this method is convenient in order to prevent duplicating sql query string matching. +func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { + eq := &ExpectedQuery{} + eq.expectSQL = e.expectSQL + eq.converter = e.mock.converter + e.mock.expected = append(e.mock.expected, eq) + return eq +} + +// ExpectExec allows to expect Exec() on this prepared statement. +// this method is convenient in order to prevent duplicating sql query string matching. 
+func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { + eq := &ExpectedExec{} + eq.expectSQL = e.expectSQL + eq.converter = e.mock.converter + e.mock.expected = append(e.mock.expected, eq) + return eq +} + +// String returns string representation +func (e *ExpectedPrepare) String() string { + msg := "ExpectedPrepare => expecting Prepare statement which:" + msg += "\n - matches sql: '" + e.expectSQL + "'" + + if e.err != nil { + msg += fmt.Sprintf("\n - should return error: %s", e.err) + } + + if e.closeErr != nil { + msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr) + } + + return msg +} + +// query based expectation +// adds a query matching logic +type queryBasedExpectation struct { + commonExpectation + expectSQL string + converter driver.ValueConverter + args []driver.Value +} + +func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) { + // catch panic + defer func() { + if e := recover(); e != nil { + _, ok := e.(error) + if !ok { + err = fmt.Errorf(e.(string)) + } + } + }() + + err = e.argsMatches(args) + return +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/expectations_before_go18.go b/vendor/github.com/DATA-DOG/go-sqlmock/expectations_before_go18.go new file mode 100644 index 000000000..e368e0405 --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/expectations_before_go18.go @@ -0,0 +1,52 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery { + e.rows = &rowSets{sets: []*Rows{rows}, ex: e} + return e +} + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + // @TODO: does it make sense to pass value instead of named value? 
+ if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + // convert to driver converter + darg, err := e.converter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/expectations_go18.go b/vendor/github.com/DATA-DOG/go-sqlmock/expectations_go18.go new file mode 100644 index 000000000..2d5ccba0d --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/expectations_go18.go @@ -0,0 +1,66 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" +) + +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { + sets := make([]*Rows, len(rows)) + for i, r := range rows { + sets[i] = r + } + e.rows = &rowSets{sets: sets, ex: e} + return e +} + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + // @TODO should we assert either all args are named or ordinal? + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + if named, isNamed := dval.(sql.NamedArg); isNamed { + dval = named.Value + if v.Name != named.Name { + return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) + } + } else if k+1 != v.Ordinal { + return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal) + } + + // convert to driver converter + darg, err := e.converter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/options.go b/vendor/github.com/DATA-DOG/go-sqlmock/options.go new file mode 100644 index 000000000..29053eeea --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/options.go @@ -0,0 +1,22 @@ +package sqlmock + +import "database/sql/driver" + +// ValueConverterOption allows to create a sqlmock connection +// with a custom ValueConverter to support drivers with special data types. 
+func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error { + return func(s *sqlmock) error { + s.converter = converter + return nil + } +} + +// QueryMatcherOption allows to customize SQL query matcher +// and match SQL query strings in more sophisticated ways. +// The default QueryMatcher is QueryMatcherRegexp. +func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error { + return func(s *sqlmock) error { + s.queryMatcher = queryMatcher + return nil + } +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/query.go b/vendor/github.com/DATA-DOG/go-sqlmock/query.go new file mode 100644 index 000000000..8e0584801 --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/query.go @@ -0,0 +1,68 @@ +package sqlmock + +import ( + "fmt" + "regexp" + "strings" +) + +var re = regexp.MustCompile("\\s+") + +// strip out new lines and trim spaces +func stripQuery(q string) (s string) { + return strings.TrimSpace(re.ReplaceAllString(q, " ")) +} + +// QueryMatcher is an SQL query string matcher interface, +// which can be used to customize validation of SQL query strings. +// As an exaple, external library could be used to build +// and validate SQL ast, columns selected. +// +// sqlmock can be customized to implement a different QueryMatcher +// configured through an option when sqlmock.New or sqlmock.NewWithDSN +// is called, default QueryMatcher is QueryMatcherRegexp. +type QueryMatcher interface { + + // Match expected SQL query string without whitespace to + // actual SQL. + Match(expectedSQL, actualSQL string) error +} + +// QueryMatcherFunc type is an adapter to allow the use of +// ordinary functions as QueryMatcher. If f is a function +// with the appropriate signature, QueryMatcherFunc(f) is a +// QueryMatcher that calls f. +type QueryMatcherFunc func(expectedSQL, actualSQL string) error + +// Match implements the QueryMatcher +func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error { + return f(expectedSQL, actualSQL) +} + +// QueryMatcherRegexp is the default SQL query matcher +// used by sqlmock. It parses expectedSQL to a regular +// expression and attempts to match actualSQL. +var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { + expect := stripQuery(expectedSQL) + actual := stripQuery(actualSQL) + re, err := regexp.Compile(expect) + if err != nil { + return err + } + if !re.MatchString(actual) { + return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String()) + } + return nil +}) + +// QueryMatcherEqual is the SQL query matcher +// which simply tries a case sensitive match of +// expected and actual SQL strings without whitespace. 
+var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { + expect := stripQuery(expectedSQL) + actual := stripQuery(actualSQL) + if actual != expect { + return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect) + } + return nil +}) diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/result.go b/vendor/github.com/DATA-DOG/go-sqlmock/result.go new file mode 100644 index 000000000..a63e72ba8 --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/result.go @@ -0,0 +1,39 @@ +package sqlmock + +import ( + "database/sql/driver" +) + +// Result satisfies sql driver Result, which +// holds last insert id and rows affected +// by Exec queries +type result struct { + insertID int64 + rowsAffected int64 + err error +} + +// NewResult creates a new sql driver Result +// for Exec based query mocks. +func NewResult(lastInsertID int64, rowsAffected int64) driver.Result { + return &result{ + insertID: lastInsertID, + rowsAffected: rowsAffected, + } +} + +// NewErrorResult creates a new sql driver Result +// which returns an error given for both interface methods +func NewErrorResult(err error) driver.Result { + return &result{ + err: err, + } +} + +func (r *result) LastInsertId() (int64, error) { + return r.insertID, r.err +} + +func (r *result) RowsAffected() (int64, error) { + return r.rowsAffected, r.err +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/rows.go b/vendor/github.com/DATA-DOG/go-sqlmock/rows.go new file mode 100644 index 000000000..0244bd4bb --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/rows.go @@ -0,0 +1,176 @@ +package sqlmock + +import ( + "database/sql/driver" + "encoding/csv" + "fmt" + "io" + "strings" +) + +// CSVColumnParser is a function which converts trimmed csv +// column string to a []byte representation. currently +// transforms NULL to nil +var CSVColumnParser = func(s string) []byte { + switch { + case strings.ToLower(s) == "null": + return nil + } + return []byte(s) +} + +type rowSets struct { + sets []*Rows + pos int + ex *ExpectedQuery +} + +func (rs *rowSets) Columns() []string { + return rs.sets[rs.pos].cols +} + +func (rs *rowSets) Close() error { + rs.ex.rowsWereClosed = true + return rs.sets[rs.pos].closeErr +} + +// advances to next row +func (rs *rowSets) Next(dest []driver.Value) error { + r := rs.sets[rs.pos] + r.pos++ + if r.pos > len(r.rows) { + return io.EOF // per interface spec + } + + for i, col := range r.rows[r.pos-1] { + dest[i] = col + } + + return r.nextErr[r.pos-1] +} + +// transforms to debuggable printable string +func (rs *rowSets) String() string { + if rs.empty() { + return "with empty rows" + } + + msg := "should return rows:\n" + if len(rs.sets) == 1 { + for n, row := range rs.sets[0].rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + return strings.TrimSpace(msg) + } + for i, set := range rs.sets { + msg += fmt.Sprintf(" result set: %d\n", i) + for n, row := range set.rows { + msg += fmt.Sprintf(" row %d - %+v\n", n, row) + } + } + return strings.TrimSpace(msg) +} + +func (rs *rowSets) empty() bool { + for _, set := range rs.sets { + if len(set.rows) > 0 { + return false + } + } + return true +} + +// Rows is a mocked collection of rows to +// return for Query result +type Rows struct { + converter driver.ValueConverter + cols []string + rows [][]driver.Value + pos int + nextErr map[int]error + closeErr error +} + +// NewRows allows Rows to be created from a +// sql driver.Value slice or from the CSV string and +// to be used as sql driver.Rows. 
+// Use Sqlmock.NewRows instead if using a custom converter +func NewRows(columns []string) *Rows { + return &Rows{ + cols: columns, + nextErr: make(map[int]error), + converter: driver.DefaultParameterConverter, + } +} + +// CloseError allows to set an error +// which will be returned by rows.Close +// function. +// +// The close error will be triggered only in cases +// when rows.Next() EOF was not yet reached, that is +// a default sql library behavior +func (r *Rows) CloseError(err error) *Rows { + r.closeErr = err + return r +} + +// RowError allows to set an error +// which will be returned when a given +// row number is read +func (r *Rows) RowError(row int, err error) *Rows { + r.nextErr[row] = err + return r +} + +// AddRow composed from database driver.Value slice +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) AddRow(values ...driver.Value) *Rows { + if len(values) != len(r.cols) { + panic("Expected number of values to match number of columns") + } + + row := make([]driver.Value, len(r.cols)) + for i, v := range values { + // Convert user-friendly values (such as int or driver.Valuer) + // to database/sql native value (driver.Value such as int64) + var err error + v, err = r.converter.ConvertValue(v) + if err != nil { + panic(fmt.Errorf( + "row #%d, column #%d (%q) type %T: %s", + len(r.rows)+1, i, r.cols[i], values[i], err, + )) + } + + row[i] = v + } + + r.rows = append(r.rows, row) + return r +} + +// FromCSVString build rows from csv string. +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) FromCSVString(s string) *Rows { + res := strings.NewReader(strings.TrimSpace(s)) + csvReader := csv.NewReader(res) + + for { + res, err := csvReader.Read() + if err != nil || res == nil { + break + } + + row := make([]driver.Value, len(r.cols)) + for i, v := range res { + row[i] = CSVColumnParser(strings.TrimSpace(v)) + } + r.rows = append(r.rows, row) + } + return r +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/rows_go18.go b/vendor/github.com/DATA-DOG/go-sqlmock/rows_go18.go new file mode 100644 index 000000000..4ecf84e7e --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/rows_go18.go @@ -0,0 +1,20 @@ +// +build go1.8 + +package sqlmock + +import "io" + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) HasNextResultSet() bool { + return rs.pos+1 < len(rs.sets) +} + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) NextResultSet() error { + if !rs.HasNextResultSet() { + return io.EOF + } + + rs.pos++ + return nil +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/sqlmock.go b/vendor/github.com/DATA-DOG/go-sqlmock/sqlmock.go new file mode 100644 index 000000000..609dafd82 --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/sqlmock.go @@ -0,0 +1,589 @@ +/* +Package sqlmock is a mock library implementing sql driver. Which has one and only +purpose - to simulate any sql driver behavior in tests, without needing a real +database connection. It helps to maintain correct **TDD** workflow. + +It does not require any modifications to your source code in order to test +and mock database operations. Supports concurrency and multiple database mocking. + +The driver allows to mock any sql driver method behavior. 
+*/ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "time" +) + +// Sqlmock interface serves to create expectations +// for any kind of database action in order to mock +// and test real database behavior. +type Sqlmock interface { + + // ExpectClose queues an expectation for this database + // action to be triggered. the *ExpectedClose allows + // to mock database response + ExpectClose() *ExpectedClose + + // ExpectationsWereMet checks whether all queued expectations + // were met in order. If any of them was not met - an error is returned. + ExpectationsWereMet() error + + // ExpectPrepare expects Prepare() to be called with expectedSQL query. + // the *ExpectedPrepare allows to mock database response. + // Note that you may expect Query() or Exec() on the *ExpectedPrepare + // statement to prevent repeating expectedSQL + ExpectPrepare(expectedSQL string) *ExpectedPrepare + + // ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query. + // the *ExpectedQuery allows to mock database response. + ExpectQuery(expectedSQL string) *ExpectedQuery + + // ExpectExec expects Exec() to be called with expectedSQL query. + // the *ExpectedExec allows to mock database response + ExpectExec(expectedSQL string) *ExpectedExec + + // ExpectBegin expects *sql.DB.Begin to be called. + // the *ExpectedBegin allows to mock database response + ExpectBegin() *ExpectedBegin + + // ExpectCommit expects *sql.Tx.Commit to be called. + // the *ExpectedCommit allows to mock database response + ExpectCommit() *ExpectedCommit + + // ExpectRollback expects *sql.Tx.Rollback to be called. + // the *ExpectedRollback allows to mock database response + ExpectRollback() *ExpectedRollback + + // MatchExpectationsInOrder gives an option whether to match all + // expectations in the order they were set or not. + // + // By default it is set to - true. But if you use goroutines + // to parallelize your query executation, that option may + // be handy. + // + // This option may be turned on anytime during tests. As soon + // as it is switched to false, expectations will be matched + // in any order. Or otherwise if switched to true, any unmatched + // expectations will be expected in order + MatchExpectationsInOrder(bool) + + // NewRows allows Rows to be created from a + // sql driver.Value slice or from the CSV string and + // to be used as sql driver.Rows. + NewRows(columns []string) *Rows +} + +type sqlmock struct { + ordered bool + dsn string + opened int + drv *mockDriver + converter driver.ValueConverter + queryMatcher QueryMatcher + + expected []expectation +} + +func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) { + db, err := sql.Open("sqlmock", c.dsn) + if err != nil { + return db, c, err + } + for _, option := range options { + err := option(c) + if err != nil { + return db, c, err + } + } + if c.converter == nil { + c.converter = driver.DefaultParameterConverter + } + if c.queryMatcher == nil { + c.queryMatcher = QueryMatcherRegexp + } + return db, c, db.Ping() +} + +func (c *sqlmock) ExpectClose() *ExpectedClose { + e := &ExpectedClose{} + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) MatchExpectationsInOrder(b bool) { + c.ordered = b +} + +// Close a mock database driver connection. It may or may not +// be called depending on the sircumstances, but if it is called +// there must be an *ExpectedClose expectation satisfied. 
+// meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *sqlmock) Close() error { + c.drv.Lock() + defer c.drv.Unlock() + + c.opened-- + if c.opened == 0 { + delete(c.drv.conns, c.dsn) + } + + var expected *ExpectedClose + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedClose); ok { + break + } + + next.Unlock() + if c.ordered { + return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next) + } + } + + if expected == nil { + msg := "call to database Close was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected.err +} + +func (c *sqlmock) ExpectationsWereMet() error { + for _, e := range c.expected { + e.Lock() + fulfilled := e.fulfilled() + e.Unlock() + + if !fulfilled { + return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) + } + + // for expected prepared statement check whether it was closed if expected + if prep, ok := e.(*ExpectedPrepare); ok { + if prep.mustBeClosed && !prep.wasClosed { + return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep) + } + } + + // must check whether all expected queried rows are closed + if query, ok := e.(*ExpectedQuery); ok { + if query.rowsMustBeClosed && !query.rowsWereClosed { + return fmt.Errorf("expected query rows to be closed, but it was not: %s", query) + } + } + } + return nil +} + +// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *sqlmock) Begin() (driver.Tx, error) { + ex, err := c.begin() + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return c, nil +} + +func (c *sqlmock) begin() (*ExpectedBegin, error) { + var expected *ExpectedBegin + var ok bool + var fulfilled int + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedBegin); ok { + break + } + + next.Unlock() + if c.ordered { + return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next) + } + } + if expected == nil { + msg := "call to database transaction Begin was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + + return expected, expected.err +} + +func (c *sqlmock) ExpectBegin() *ExpectedBegin { + e := &ExpectedBegin{} + c.expected = append(c.expected, e) + return e +} + +// Exec meets http://golang.org/pkg/database/sql/driver/#Execer +func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.exec(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.result, nil +} + +func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { + var expected *ExpectedExec + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = 
next.(*ExpectedExec); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if exec, ok := next.(*ExpectedExec); ok { + if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { + next.Unlock() + continue + } + + if err := exec.attemptArgMatch(args); err == nil { + expected = exec + break + } + } + next.Unlock() + } + if expected == nil { + msg := "call to ExecQuery '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("ExecQuery: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.result == nil { + return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + + return expected, nil +} + +func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec { + e := &ExpectedExec{} + e.expectSQL = expectedSQL + e.converter = c.converter + c.expected = append(c.expected, e) + return e +} + +// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { + ex, err := c.prepare(query) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return &statement{c, ex, query}, nil +} + +func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { + var expected *ExpectedPrepare + var fulfilled int + var ok bool + + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedPrepare); ok { + break + } + + next.Unlock() + return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next) + } + + if pr, ok := next.(*ExpectedPrepare); ok { + if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil { + expected = pr + break + } + } + next.Unlock() + } + + if expected == nil { + msg := "call to Prepare '%s' query was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query) + } + defer expected.Unlock() + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Prepare: %v", err) + } + + expected.triggered = true + return expected, expected.err +} + +func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare { + e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c} + c.expected = append(c.expected, e) + return e +} + +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} + +// Query meets http://golang.org/pkg/database/sql/driver/#Queryer +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.query(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + 
if err != nil { + return nil, err + } + + return ex.rows, nil +} + +func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { + var expected *ExpectedQuery + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedQuery); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if qr, ok := next.(*ExpectedQuery); ok { + if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { + next.Unlock() + continue + } + if err := qr.attemptArgMatch(args); err == nil { + expected = qr + break + } + } + next.Unlock() + } + + if expected == nil { + msg := "call to Query '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Query: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.rows == nil { + return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + return expected, nil +} + +func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery { + e := &ExpectedQuery{} + e.expectSQL = expectedSQL + e.converter = c.converter + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) ExpectCommit() *ExpectedCommit { + e := &ExpectedCommit{} + c.expected = append(c.expected, e) + return e +} + +func (c *sqlmock) ExpectRollback() *ExpectedRollback { + e := &ExpectedRollback{} + c.expected = append(c.expected, e) + return e +} + +// Commit meets http://golang.org/pkg/database/sql/driver/#Tx +func (c *sqlmock) Commit() error { + var expected *ExpectedCommit + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedCommit); ok { + break + } + + next.Unlock() + if c.ordered { + return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next) + } + } + if expected == nil { + msg := "call to Commit transaction was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected.err +} + +// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx +func (c *sqlmock) Rollback() error { + var expected *ExpectedRollback + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedRollback); ok { + break + } + + next.Unlock() + if c.ordered { + return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next) + } + } + if expected == nil { + msg := "call to Rollback transaction was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations 
were already fulfilled, " + msg + } + return fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected.err +} + +// NewRows allows Rows to be created from a +// sql driver.Value slice or from the CSV string and +// to be used as sql driver.Rows. +func (c *sqlmock) NewRows(columns []string) *Rows { + r := NewRows(columns) + r.converter = c.converter + return r +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/sqlmock_go18.go b/vendor/github.com/DATA-DOG/go-sqlmock/sqlmock_go18.go new file mode 100644 index 000000000..0afb29682 --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/sqlmock_go18.go @@ -0,0 +1,121 @@ +// +build go1.8 + +package sqlmock + +import ( + "context" + "database/sql/driver" + "errors" + "time" +) + +// ErrCancelled defines an error value, which can be expected in case of +// such cancellation error. +var ErrCancelled = errors.New("canceling query due to user request") + +// Implement the "QueryerContext" interface +func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.query(query, namedArgs) + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return ex.rows, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "ExecerContext" interface +func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.exec(query, namedArgs) + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return ex.result, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "ConnBeginTx" interface +func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + ex, err := c.begin() + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return c, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "ConnPrepareContext" interface +func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + ex, err := c.prepare(query) + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return &statement{c, ex, query}, nil + case <-ctx.Done(): + return nil, ErrCancelled + } + } + + return nil, err +} + +// Implement the "Pinger" interface +// for now we do not have a Ping expectation +// may be something for the future +func (c *sqlmock) Ping(ctx context.Context) error { + return nil +} + +// Implement the "StmtExecContext" interface +func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return stmt.conn.ExecContext(ctx, stmt.query, args) +} + +// Implement the "StmtQueryContext" interface +func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return stmt.conn.QueryContext(ctx, stmt.query, args) +} + +// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) + +// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker +func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) { 
+ nv.Value, err = c.converter.ConvertValue(nv.Value) + return err +} diff --git a/vendor/github.com/DATA-DOG/go-sqlmock/statement.go b/vendor/github.com/DATA-DOG/go-sqlmock/statement.go new file mode 100644 index 000000000..570efd99a --- /dev/null +++ b/vendor/github.com/DATA-DOG/go-sqlmock/statement.go @@ -0,0 +1,28 @@ +package sqlmock + +import ( + "database/sql/driver" +) + +type statement struct { + conn *sqlmock + ex *ExpectedPrepare + query string +} + +func (stmt *statement) Close() error { + stmt.ex.wasClosed = true + return stmt.ex.closeErr +} + +func (stmt *statement) NumInput() int { + return -1 +} + +func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { + return stmt.conn.Exec(stmt.query, args) +} + +func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { + return stmt.conn.Query(stmt.query, args) +} diff --git a/vendor/github.com/pingcap/tidb-tools/tidb-binlog/driver/reader/offset.go b/vendor/github.com/pingcap/tidb-tools/tidb-binlog/driver/reader/offset.go index b8d7bb20a..0e3d4a19d 100644 --- a/vendor/github.com/pingcap/tidb-tools/tidb-binlog/driver/reader/offset.go +++ b/vendor/github.com/pingcap/tidb-tools/tidb-binlog/driver/reader/offset.go @@ -148,7 +148,7 @@ func (ks *KafkaSeeker) seekOffset(topic string, partition int32, start int64, en } if endTS <= ts { - return sarama.OffsetNewest, nil + return end + 1, nil } return end, nil diff --git a/vendor/golang.org/x/sync/LICENSE b/vendor/golang.org/x/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/golang.org/x/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sync/PATENTS b/vendor/golang.org/x/sync/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/sync/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. 
+ +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/sync/errgroup/errgroup.go b/vendor/golang.org/x/sync/errgroup/errgroup.go new file mode 100644 index 000000000..533438d91 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup.go @@ -0,0 +1,67 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup provides synchronization, error propagation, and Context +// cancelation for groups of goroutines working on subtasks of a common task. +package errgroup + +import ( + "sync" + + "golang.org/x/net/context" +) + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid and does not cancel on error. +type Group struct { + cancel func() + + wg sync.WaitGroup + + errOnce sync.Once + err error +} + +// WithContext returns a new Group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. +func WithContext(ctx context.Context) (*Group, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &Group{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// +// The first call to return a non-nil error cancels the group; its error will be +// returned by Wait. +func (g *Group) Go(f func() error) { + g.wg.Add(1) + + go func() { + defer g.wg.Done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() +}
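
A few notes on the code this change vendors, with hedged usage sketches; none of the snippets below are from the PR itself. First, the newly vendored go-sqlmock dependency (pulled in for the loader tests): the usual flow is to open a mock connection, queue ordered expectations, run the code under test, then assert that every expectation fired. A minimal sketch under assumed names — the `users` table, SQL text, and test file placement are hypothetical:

```go
package loader_test

import (
	"testing"

	sqlmock "github.com/DATA-DOG/go-sqlmock"
)

func TestInsertUser(t *testing.T) {
	// sqlmock.New returns a *sql.DB backed by the mocked driver vendored above.
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("open sqlmock: %v", err)
	}
	defer db.Close()

	// Expectations are consumed in order (c.ordered); the argument to
	// ExpectExec is matched against the executed SQL by c.queryMatcher.
	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO users").
		WithArgs(1, "alice").
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()

	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("begin: %v", err)
	}
	if _, err := tx.Exec("INSERT INTO users(id, name) VALUES (?, ?)", 1, "alice"); err != nil {
		t.Fatalf("exec: %v", err)
	}
	if err := tx.Commit(); err != nil {
		t.Fatalf("commit: %v", err)
	}

	// Fails the test if any queued expectation was never triggered.
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet expectations: %v", err)
	}
}
```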
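
The go1.8 build file (sqlmock_go18.go) races each expectation's configured delay against `ctx.Done()`, surfacing `ErrCancelled` when the context wins. A sketch of how that is exercised, extending the hypothetical test file above (adds `context` and `time` imports); `WillDelayFor` is, to my understanding, the public setter behind `ex.delay` in this sqlmock version:

```go
func TestQueryContextCancel(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("open sqlmock: %v", err)
	}
	defer db.Close()

	// The expectation's delay (1s) deliberately outlives the context deadline (10ms).
	mock.ExpectQuery("SELECT id FROM users").
		WillDelayFor(time.Second).
		WillReturnRows(sqlmock.NewRows([]string{"id"}))

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	// QueryContext selects on time.After(ex.delay) vs ctx.Done(), so this
	// should return an error (sqlmock.ErrCancelled at the driver level)
	// instead of blocking for the full delay.
	if _, err := db.QueryContext(ctx, "SELECT id FROM users"); err == nil {
		t.Fatal("expected a cancellation error, got nil")
	}
}
```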
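
statement.go routes `Stmt.Exec`/`Stmt.Query` back through the parent `sqlmock` connection, so prepared statements are matched by the same machinery as direct calls, and `ExpectPrepare` chains into `ExpectQuery`/`ExpectExec` for the statement's execution. Another hedged fragment of the same hypothetical test file:

```go
func TestPreparedSelect(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("open sqlmock: %v", err)
	}
	defer db.Close()

	// The Prepare call and the subsequent Query are separate expectations,
	// the latter chained off ExpectPrepare.
	mock.ExpectPrepare("SELECT id, name FROM users").
		ExpectQuery().
		WithArgs(1).
		WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "alice"))

	stmt, err := db.Prepare("SELECT id, name FROM users WHERE id = ?")
	if err != nil {
		t.Fatalf("prepare: %v", err)
	}
	defer stmt.Close()

	rows, err := stmt.Query(1)
	if err != nil {
		t.Fatalf("query: %v", err)
	}
	rows.Close()

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet expectations: %v", err)
	}
}
```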
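
On the one-line change in reader/offset.go: reading the diff, `endTS` appears to be the commit timestamp probed at offset `end` and `ts` the target timestamp the seek wants to start from. When even the last examined message is at or before the target, the old return value `sarama.OffsetNewest` is a moving target that can skip messages appended between the probe and the actual read; `end + 1` resumes deterministically just after the last examined offset. A commented restatement of that tail case (sketch only; the real `seekOffset` also binary-searches the range, and the helper name is made up):

```go
// seekTail mirrors the tail case of KafkaSeeker.seekOffset after this PR.
// endTS is the commit timestamp read at offset end; ts is the target.
func seekTail(end, endTS, ts int64) int64 {
	if endTS <= ts {
		// Every message up to `end` is at or before the target, so resume
		// just after it rather than jumping to sarama.OffsetNewest, which
		// could skip messages produced after this probe.
		return end + 1
	}
	return end // the message at `end` is newer than the target
}
```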
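
Finally, the vendored errgroup is small enough to summarize by example: the first goroutine to return a non-nil error cancels the group's derived context, and that error is what `Wait` hands back. A minimal, self-contained sketch (the URLs are placeholders); note it imports `golang.org/x/net/context` to match the vendored file, which aliases the standard `context` package on modern Go:

```go
package main

import (
	"fmt"
	"net/http"

	"golang.org/x/net/context"
	"golang.org/x/sync/errgroup"
)

func main() {
	g, ctx := errgroup.WithContext(context.Background())

	urls := []string{"http://example.com/a", "http://example.com/b"}
	for _, u := range urls {
		u := u // pin the loop variable for the closure
		g.Go(func() error {
			req, err := http.NewRequest("GET", u, nil)
			if err != nil {
				return err
			}
			// Tie each request to the group's context so the first
			// failure cancels the remaining fetches.
			resp, err := http.DefaultClient.Do(req.WithContext(ctx))
			if err != nil {
				return err
			}
			return resp.Body.Close()
		})
	}

	// Wait blocks for all goroutines and returns the first non-nil error.
	if err := g.Wait(); err != nil {
		fmt.Println("fetch failed:", err)
	}
}
```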