Skip to content

Commit

Permalink
all: add support for cockroachdb
Browse files Browse the repository at this point in the history
add support for cockroachdb

Signed-off-by: David López <[email protected]>
  • Loading branch information
lopezator committed Apr 17, 2019
1 parent 5a284d6 commit 2ec34ea
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 50 deletions.
5 changes: 4 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ jobs:
- image: circleci/golang:1.11.5
environment:
- TEST_DATABASE_POSTGRESQL=postgres://test:test@localhost:5432/sqlcon?sslmode=disable
- TEST_DATABASE_MYSQL=root:test@(localhost:3306)/mysql?parseTime=true
- TEST_DATABASE_MYSQL=mysql://root:test@(localhost:3306)/mysql?parseTime=true
- TEST_DATABASE_COCKROACHDB=cockroach://root@localhost:26257/defaultdb?sslmode=disable
- image: postgres:9.5
environment:
- POSTGRES_USER=test
Expand All @@ -17,6 +18,8 @@ jobs:
- image: mysql:5.7
environment:
- MYSQL_ROOT_PASSWORD=test
- image: cockroachdb/cockroach:v2.1.6
command: start --insecure
working_directory: /go/src/github.com/ory/sqlcon
steps:
- checkout
Expand Down
11 changes: 8 additions & 3 deletions dbal/canonicalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,31 @@ const (
// DriverMySQL is the mysql driver name.
DriverMySQL = "mysql"

// DriverPostgreSQL is the mysql driver name.
// DriverPostgreSQL is the postgres driver name.
DriverPostgreSQL = "postgres"

// DriverCockroachDB is the cockroach driver name.
DriverCockroachDB = "cockroach"

// UnknownDriver is the driver name if the driver is unknown.
UnknownDriver = "unknown"
)

// Canonicalize returns constants DriverMySQL, DriverPostgreSQL, UnknownDriver, depending on `database`.
// Canonicalize returns constants DriverMySQL, DriverPostgreSQL, DriverCockroachDB, UnknownDriver, depending on `database`.
func Canonicalize(database string) string {
switch database {
case "mysql":
return DriverMySQL
case "pgx", "pq", "postgres":
return DriverPostgreSQL
case "cockroach":
return DriverCockroachDB
default:
return UnknownDriver
}
}

// MustCanonicalize returns constants DriverMySQL, DriverPostgreSQL or fatals.
// MustCanonicalize returns constants DriverMySQL, DriverPostgreSQL, DriverCockroachDB or fatals.
func MustCanonicalize(database string) string {
d := Canonicalize(database)
if d == UnknownDriver {
Expand Down
5 changes: 3 additions & 2 deletions dbal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import (
"net/url"
"time"

"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"

"github.com/ory/x/sqlcon"
)

// Connect is a wrapper for connecting to different SQL drivers.
func Connect(db string, logger logrus.FieldLogger, memf func() error, sqlf func(db *sqlx.DB) error) error {
func Connect(db string, logger logrus.FieldLogger, memf func() error, sqlf func(db *sqlcon.DB) error) error {
if db == "memory" {
return memf()
} else if db == "" {
Expand All @@ -27,6 +26,8 @@ func Connect(db string, logger logrus.FieldLogger, memf func() error, sqlf func(
switch u.Scheme {
case "postgres":
fallthrough
case "cockroach":
fallthrough
case "mysql":
c, err := sqlcon.NewSQLConnection(db, logger)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion dbal/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func NewPackerMigrationSource(l logrus.FieldLogger, sources []string, loader fun

var found bool
for _, f := range filters {
if strings.Contains(source, f) {
if filepath.Dir(source) == f {
found = true
}
}
Expand Down
30 changes: 19 additions & 11 deletions dbal/migratest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ import (
"sync"
"testing"

"github.com/ory/x/dbal"

"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
migrate "github.com/rubenv/sql-migrate"
"github.com/stretchr/testify/require"

"github.com/ory/x/dbal"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlcon/dockertest"
)

Expand All @@ -21,16 +20,16 @@ type MigrationSchemas []map[string]*dbal.PackrMigrationSource
// RunPackrMigrationTests runs migration tests from packr migrations.
func RunPackrMigrationTests(
t *testing.T, schema, data MigrationSchemas,
init, cleanup func(*testing.T, *sqlx.DB),
runner func(*testing.T, *sqlx.DB, int, int, int),
init, cleanup func(*testing.T, *sqlcon.DB),
runner func(*testing.T, *sqlcon.DB, int, int, int),
) {
if testing.Short() {
t.SkipNow()
return
}

var m sync.Mutex
var dbs = map[string]*sqlx.DB{}
var dbs = map[string]*sqlcon.DB{}
var mid = uuid.New()

dockertest.Parallel([]func(){
Expand All @@ -40,7 +39,7 @@ func RunPackrMigrationTests(
t.Fatalf("Could not connect to database: %v", err)
}
m.Lock()
dbs["postgres"] = db
dbs["postgres"] = sqlcon.NewDB(db, "postgres")
m.Unlock()
},
func() {
Expand All @@ -49,7 +48,16 @@ func RunPackrMigrationTests(
t.Fatalf("Could not connect to database: %v", err)
}
m.Lock()
dbs["mysql"] = db
dbs["mysql"] = sqlcon.NewDB(db, "mysql")
m.Unlock()
},
func() {
db, err := dockertest.ConnectToTestCockroachDB()
if err != nil {
t.Fatalf("Could not connect to database: %v", err)
}
m.Lock()
dbs["cockroach"] = sqlcon.NewDB(db, "cockroach")
m.Unlock()
},
})
Expand All @@ -68,7 +76,7 @@ func RunPackrMigrationTests(
for step := 0; step < steps; step++ {
t.Run(fmt.Sprintf("up=%d", step), func(t *testing.T) {
migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
n, err := migrate.ExecMax(db.DB, db.DriverName(), ss[name], migrate.Up, 1)
n, err := migrate.ExecMax(db.DB.DB, db.Dialect(), ss[name], migrate.Up, 1)
require.NoError(t, err)
require.Equal(t, n, 1, sk)

Expand All @@ -79,7 +87,7 @@ func RunPackrMigrationTests(
}

migrate.SetTable(fmt.Sprintf("%s_%d_data", mid, sk))
n, err = migrate.ExecMax(db.DB, db.DriverName(), data[sk][name], migrate.Up, 1)
n, err = migrate.ExecMax(db.DB.DB, db.Dialect(), data[sk][name], migrate.Up, 1)
require.NoError(t, err)
require.Equal(t, n, 1)
})
Expand All @@ -103,7 +111,7 @@ func RunPackrMigrationTests(
migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
for step := 0; step < steps; step++ {
t.Run(fmt.Sprintf("down=%d", step), func(t *testing.T) {
n, err := migrate.ExecMax(db.DB, db.DriverName(), ss[name], migrate.Down, 1)
n, err := migrate.ExecMax(db.DB.DB, db.Dialect(), ss[name], migrate.Down, 1)
require.NoError(t, err)
require.Equal(t, n, 1)
})
Expand Down
63 changes: 59 additions & 4 deletions sqlcon/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package sqlcon

import (
"context"
"database/sql"
"fmt"
"net/url"
Expand All @@ -43,12 +44,55 @@ import (

// SQLConnection represents a connection to a SQL database.
type SQLConnection struct {
db *sqlx.DB
db *DB
URL *url.URL
L logrus.FieldLogger
options
}

// DB represents a wrapped sqlx.DB with own defined driver name.
type DB struct {
*sqlx.DB
driverName string
bindType int
}

// NewDB returns a new DB
func NewDB(sqlxDB *sqlx.DB, driverName string) *DB {
db := &DB{DB: sqlxDB, driverName: driverName}
if driverName != "cockroach" {
db.bindType = sqlx.BindType(driverName)
} else {
db.bindType = sqlx.DOLLAR
}
return db
}

// Rebind wraps sqlx.DB.Rebind
func (d *DB) Rebind(query string) string {
return sqlx.Rebind(d.bindType, query)
}

// NamedExecContext wraps sqlx.DB.NamedExecContext
func (d *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
return d.DB.NamedExecContext(ctx, d.Rebind(query), arg)
}

// DriverName returns db.driverName
func (d *DB) DriverName() string {
return d.driverName
}

// Dialect returns sql.DB.DriverName
func (d *DB) Dialect() string {
dialect := d.DB.DriverName()
switch dialect {
case "pgx", "pq":
dialect = "postgres"
}
return dialect
}

// NewSQLConnection returns a new SQLConnection.
func NewSQLConnection(db string, l logrus.FieldLogger, opts ...OptionModifier) (*SQLConnection, error) {
u, err := url.Parse(db)
Expand Down Expand Up @@ -92,7 +136,7 @@ func cleanURLQuery(c *url.URL) *url.URL {
}

// GetDatabaseRetry tries to connect to a database and fails after failAfter.
func (c *SQLConnection) GetDatabaseRetry(maxWait time.Duration, failAfter time.Duration) (*sqlx.DB, error) {
func (c *SQLConnection) GetDatabaseRetry(maxWait time.Duration, failAfter time.Duration) (*DB, error) {
if err := resilience.Retry(c.L, maxWait, failAfter, func() (err error) {
c.db, err = c.GetDatabase()
if err != nil {
Expand All @@ -107,7 +151,7 @@ func (c *SQLConnection) GetDatabaseRetry(maxWait time.Duration, failAfter time.D
}

// GetDatabase retrusn a database instance.
func (c *SQLConnection) GetDatabase() (*sqlx.DB, error) {
func (c *SQLConnection) GetDatabase() (*DB, error) {
if c.db != nil {
return c.db, nil
}
Expand All @@ -123,12 +167,15 @@ func (c *SQLConnection) GetDatabase() (*sqlx.DB, error) {
c.L.Infof("Connecting with %s", c.URL.Scheme+"://*:*@"+c.URL.Host+c.URL.Path+"?"+clean.RawQuery)
u := connectionString(clean)

if registeredDriver == "cockroach" {
registeredDriver = "postgres"
}
db, err := sql.Open(registeredDriver, u)
if err != nil {
return nil, errors.Wrapf(err, "could not open SQL connection")
}

c.db = sqlx.NewDb(db, clean.Scheme)
c.db = NewDB(sqlx.NewDb(db, registeredDriver), clean.Scheme)
if err := c.db.Ping(); err != nil {
return nil, errors.Wrapf(err, "could not ping SQL connection")
}
Expand Down Expand Up @@ -204,6 +251,9 @@ func connectionString(clean *url.URL) string {
if clean.Scheme == "mysql" {
u = strings.Replace(u, "mysql://", "", -1)
}
if clean.Scheme == "cockroach" {
u = strings.Replace(u, "cockroach://", "postgres://", 1)
}
return u
}

Expand All @@ -229,6 +279,11 @@ func (c *SQLConnection) registerDriver() (string, error) {
// and does not satisfy the driver.Driver interface.
sql.Register(driverName,
instrumentedsql.WrapDriver(&pq.Driver{}, tracingOpts...))
case "cockroach":
// Why does this have to be a pointer? Because the Open method for postgres has a pointer receiver
// and does not satisfy the driver.Driver interface.
sql.Register(driverName,
instrumentedsql.WrapDriver(&pq.Driver{}, tracingOpts...))
default:
return "", fmt.Errorf("unsupported scheme (%s) in DSN", c.URL.Scheme)
}
Expand Down
Loading

0 comments on commit 2ec34ea

Please sign in to comment.