Skip to content

Commit

Permalink
all: add support for cockroachdb
Browse files Browse the repository at this point in the history
add support for cockroachdb

Signed-off-by: David López <[email protected]>
  • Loading branch information
lopezator committed Apr 30, 2019
1 parent c3f1077 commit 3190842
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 39 deletions.
5 changes: 4 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ jobs:
- image: circleci/golang:1.11.5
environment:
- TEST_DATABASE_POSTGRESQL=postgres://test:test@localhost:5432/sqlcon?sslmode=disable
- TEST_DATABASE_MYSQL=root:test@(localhost:3306)/mysql?parseTime=true
- TEST_DATABASE_MYSQL=mysql://root:test@(localhost:3306)/mysql?parseTime=true
- TEST_DATABASE_COCKROACHDB=cockroach://root@localhost:26257/defaultdb?sslmode=disable
- image: postgres:9.5
environment:
- POSTGRES_USER=test
Expand All @@ -17,6 +18,8 @@ jobs:
- image: mysql:5.7
environment:
- MYSQL_ROOT_PASSWORD=test
- image: cockroachdb/cockroach:v2.1.6
command: start --insecure
working_directory: /go/src/github.com/ory/sqlcon
steps:
- checkout
Expand Down
11 changes: 8 additions & 3 deletions dbal/canonicalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,31 @@ const (
// DriverMySQL is the mysql driver name.
DriverMySQL = "mysql"

// DriverPostgreSQL is the mysql driver name.
// DriverPostgreSQL is the postgres driver name.
DriverPostgreSQL = "postgres"

// DriverCockroachDB is the cockroach driver name.
DriverCockroachDB = "cockroach"

// UnknownDriver is the driver name if the driver is unknown.
UnknownDriver = "unknown"
)

// Canonicalize returns constants DriverMySQL, DriverPostgreSQL, UnknownDriver, depending on `database`.
// Canonicalize returns constants DriverMySQL, DriverPostgreSQL, DriverCockroachDB, UnknownDriver, depending on `database`.
func Canonicalize(database string) string {
switch database {
case "mysql":
return DriverMySQL
case "pgx", "pq", "postgres":
return DriverPostgreSQL
case "cockroach":
return DriverCockroachDB
default:
return UnknownDriver
}
}

// MustCanonicalize returns constants DriverMySQL, DriverPostgreSQL or fatals.
// MustCanonicalize returns constants DriverMySQL, DriverPostgreSQL, DriverCockroachDB or fatals.
func MustCanonicalize(database string) string {
d := Canonicalize(database)
if d == UnknownDriver {
Expand Down
2 changes: 2 additions & 0 deletions dbal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func Connect(db string, logger logrus.FieldLogger, memf func() error, sqlf func(
switch u.Scheme {
case "postgres":
fallthrough
case "cockroach":
fallthrough
case "mysql":
c, err := sqlcon.NewSQLConnection(db, logger)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion dbal/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func NewPackerMigrationSource(l logrus.FieldLogger, sources []string, loader fun

var found bool
for _, f := range filters {
if strings.Contains(source, f) {
if filepath.Dir(source) == f {
found = true
}
}
Expand Down
23 changes: 16 additions & 7 deletions dbal/migratest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ import (
"sync"
"testing"

"github.com/ory/x/dbal"

"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
migrate "github.com/rubenv/sql-migrate"
"github.com/stretchr/testify/require"

"github.com/ory/x/dbal"
"github.com/ory/x/sqlcon/dockertest"
)

Expand All @@ -22,7 +21,7 @@ type MigrationSchemas []map[string]*dbal.PackrMigrationSource
func RunPackrMigrationTests(
t *testing.T, schema, data MigrationSchemas,
init, cleanup func(*testing.T, *sqlx.DB),
runner func(*testing.T, *sqlx.DB, int, int, int),
runner func(*testing.T, string, *sqlx.DB, int, int, int),
) {
if testing.Short() {
t.SkipNow()
Expand Down Expand Up @@ -52,13 +51,23 @@ func RunPackrMigrationTests(
dbs["mysql"] = db
m.Unlock()
},
func() {
db, err := dockertest.ConnectToTestCockroachDB()
if err != nil {
t.Fatalf("Could not connect to database: %v", err)
}
m.Lock()
dbs["cockroach"] = db
m.Unlock()
},
})

if data != nil {
require.Equal(t, len(schema), len(data))
}

for name, db := range dbs {
dialect := db.DriverName()
t.Run(fmt.Sprintf("database=%s", name), func(t *testing.T) {
init(t, db)

Expand All @@ -68,7 +77,7 @@ func RunPackrMigrationTests(
for step := 0; step < steps; step++ {
t.Run(fmt.Sprintf("up=%d", step), func(t *testing.T) {
migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
n, err := migrate.ExecMax(db.DB, db.DriverName(), ss[name], migrate.Up, 1)
n, err := migrate.ExecMax(db.DB, dialect, ss[name], migrate.Up, 1)
require.NoError(t, err)
require.Equal(t, n, 1, sk)

Expand All @@ -79,7 +88,7 @@ func RunPackrMigrationTests(
}

migrate.SetTable(fmt.Sprintf("%s_%d_data", mid, sk))
n, err = migrate.ExecMax(db.DB, db.DriverName(), data[sk][name], migrate.Up, 1)
n, err = migrate.ExecMax(db.DB, dialect, data[sk][name], migrate.Up, 1)
require.NoError(t, err)
require.Equal(t, n, 1)
})
Expand All @@ -88,7 +97,7 @@ func RunPackrMigrationTests(

for step := 0; step < steps; step++ {
t.Run(fmt.Sprintf("runner=%d", step), func(t *testing.T) {
runner(t, db, sk, step, steps)
runner(t, name, db, sk, step, steps)
})
}
})
Expand All @@ -103,7 +112,7 @@ func RunPackrMigrationTests(
migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
for step := 0; step < steps; step++ {
t.Run(fmt.Sprintf("down=%d", step), func(t *testing.T) {
n, err := migrate.ExecMax(db.DB, db.DriverName(), ss[name], migrate.Down, 1)
n, err := migrate.ExecMax(db.DB, dialect, ss[name], migrate.Down, 1)
require.NoError(t, err)
require.Equal(t, n, 1)
})
Expand Down
10 changes: 9 additions & 1 deletion sqlcon/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,15 @@ func (c *SQLConnection) GetDatabase() (*sqlx.DB, error) {
c.L.Infof("Connecting with %s", c.URL.Scheme+"://*:*@"+c.URL.Host+c.URL.Path+"?"+clean.RawQuery)
u := connectionString(clean)

if registeredDriver == "cockroach" {
registeredDriver = "postgres"
}
db, err := sql.Open(registeredDriver, u)
if err != nil {
return nil, errors.Wrapf(err, "could not open SQL connection")
}

c.db = sqlx.NewDb(db, clean.Scheme)
c.db = sqlx.NewDb(db, registeredDriver)
if err := c.db.Ping(); err != nil {
return nil, errors.Wrapf(err, "could not ping SQL connection")
}
Expand Down Expand Up @@ -205,6 +208,9 @@ func connectionString(clean *url.URL) string {
if clean.Scheme == "mysql" {
u = strings.Replace(u, "mysql://", "", -1)
}
if clean.Scheme == "cockroach" {
u = strings.Replace(u, "cockroach://", "postgres://", 1)
}
return u
}

Expand All @@ -229,6 +235,8 @@ func (c *SQLConnection) registerDriver() (string, error) {
case "mysql":
sql.Register(driverName,
instrumentedsql.WrapDriver(mysql.MySQLDriver{}, tracingOpts...))
case "cockroach":
fallthrough
case "postgres":
// Why does this have to be a pointer? Because the Open method for postgres has a pointer receiver
// and does not satisfy the driver.Driver interface.
Expand Down
76 changes: 60 additions & 16 deletions sqlcon/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/opentracing/opentracing-go"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
Expand All @@ -45,9 +45,10 @@ import (
)

var (
mysqlURL *url.URL
postgresURL *url.URL
resources []*dockertest.Resource
mysqlURL *url.URL
postgresURL *url.URL
cockroachURL *url.URL
resources []*dockertest.Resource
)

func TestMain(m *testing.M) {
Expand All @@ -56,6 +57,7 @@ func TestMain(m *testing.M) {
dockertestd.Parallel([]func(){
bootstrapMySQL,
bootstrapPostgres,
bootstrapCockroach,
})
}

Expand All @@ -79,8 +81,9 @@ func TestDistributedTracing(t *testing.T) {
}

databases := map[string]string{
"mysql": mysqlURL.String(),
"postgres": postgresURL.String(),
"mysql": mysqlURL.String(),
"postgres": postgresURL.String(),
"cockroach": cockroachURL.String(),
}

for driver, dsn := range databases {
Expand Down Expand Up @@ -144,9 +147,6 @@ func TestDistributedTracing(t *testing.T) {
}

func TestRegisterDriver(t *testing.T) {
unsupportedDSN := "unsupported://unsupported:secret@localhost:1337/mydb"
supportedDSN := "mysql://foo@bar:baz@qux/db"

for _, testCase := range []struct {
description string
sqlConnection *SQLConnection
Expand All @@ -155,20 +155,20 @@ func TestRegisterDriver(t *testing.T) {
}{
{
description: "should return error if supplied DSN is unsupported for tracing",
sqlConnection: mustSQL(t, unsupportedDSN, WithDistributedTracing()),
sqlConnection: mustSQL(t, "unsupported://unsupported:secret@localhost:1337/mydb", WithDistributedTracing()),
expectedDriverName: "",
shouldError: true,
},
{
description: "should return registered driver name if supplied DSN is valid for tracing",
sqlConnection: mustSQL(t, supportedDSN, WithDistributedTracing()),
sqlConnection: mustSQL(t, "mysql://foo@bar:baz@qux/db", WithDistributedTracing()),
expectedDriverName: "instrumented-sql-driver",
shouldError: false,
},
{
description: "should return registered driver name if tracing is NOT configured",
sqlConnection: mustSQL(t, supportedDSN),
expectedDriverName: "mysql",
description: "should return cockroach driver if a valid cockroach DSN is supplied",
sqlConnection: mustSQL(t, "cockroach://foo@bar:baz@qux/db"),
expectedDriverName: "cockroach",
shouldError: false,
},
} {
Expand Down Expand Up @@ -243,6 +243,18 @@ func TestSQLConnection(t *testing.T) {
d: "pg max_conn_lifetime",
s: mustSQL(t, merge(postgresURL, map[string]string{"max_conn_lifetime": "1h", "max_idle_conns": "10", "max_conns": "10"}).String()),
},
{
d: "crdb raw",
s: mustSQL(t, cockroachURL.String()),
},
{
d: "crdb max_conn_lifetime",
s: mustSQL(t, merge(cockroachURL, map[string]string{"max_conn_lifetime": "1h"}).String()),
},
{
d: "crdb max_conn_lifetime",
s: mustSQL(t, merge(cockroachURL, map[string]string{"max_conn_lifetime": "1h", "max_idle_conns": "10", "max_conns": "10"}).String()),
},
} {
t.Run(fmt.Sprintf("case=%s", tc.d), func(t *testing.T) {
tc.s.L = logrus.New()
Expand Down Expand Up @@ -276,11 +288,11 @@ func killAll() {
func bootstrapMySQL() {
if uu := os.Getenv("TEST_DATABASE_MYSQL"); uu != "" {
log.Println("Found mysql test database config, skipping dockertest...")
_, err := sqlx.Open("postgres", uu)
_, err := sqlx.Open("mysql", uu)
if err != nil {
log.Fatalf("Could not connect to bootstrapped database: %s", err)
}
u, _ := url.Parse("mysql://" + uu)
u, _ := url.Parse(uu)
mysqlURL = u
return
}
Expand Down Expand Up @@ -330,6 +342,38 @@ func bootstrapPostgres() {
postgresURL = u
}

func bootstrapCockroach() {
if uu := os.Getenv("TEST_DATABASE_COCKROACHDB"); uu != "" {
log.Println("Found cockroachdb test database config, skipping dockertest...")
_, err := sqlx.Open("postgres", uu)
if err != nil {
log.Fatalf("Could not connect to bootstrapped database: %s", err)
}
u, _ := url.Parse(uu)
cockroachURL = u
return
}

pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not Connect to docker: %s", err)
}

resource, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "cockroachdb/cockroach",
Tag: "v2.1.6",
Cmd: []string{"start --insecure"},
})
if err != nil {
log.Fatalf("Could not start resource: %s", err)
}

urls := bootstrap("postgres://root@localhost:%s/defaultdb?sslmode=disable", "26257/tcp", "postgres", pool, resource)
resources = append(resources, resource)
u, _ := url.Parse(strings.Replace(urls, "postgres://", "cockroach://", 1))
cockroachURL = u
}

func bootstrap(u, port, driver string, pool *dockertest.Pool, resource *dockertest.Resource) (urls string) {
if err := pool.Retry(func() error {
var err error
Expand Down
Loading

0 comments on commit 3190842

Please sign in to comment.