Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add context.Context everywhere #1132

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cassandra

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -52,7 +53,7 @@ type Cassandra struct {
config *Config
}

func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
func WithInstance(ctx context.Context, session *gocql.Session, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
} else if len(config.KeyspaceName) == 0 {
Expand All @@ -76,14 +77,14 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro
config: config,
}

if err := c.ensureVersionTable(); err != nil {
if err := c.ensureVersionTable(ctx); err != nil {
return nil, err
}

return c, nil
}

func (c *Cassandra) Open(url string) (database.Driver, error) {
func (c *Cassandra) Open(ctx context.Context, url string) (database.Driver, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
Expand Down Expand Up @@ -185,34 +186,34 @@ func (c *Cassandra) Open(url string) (database.Driver, error) {
}
}

return WithInstance(session, &Config{
return WithInstance(ctx, session, &Config{
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
MigrationsTable: u.Query().Get("x-migrations-table"),
MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
MultiStatementMaxSize: multiStatementMaxSize,
})
}

func (c *Cassandra) Close() error {
func (c *Cassandra) Close(ctx context.Context) error {
c.session.Close()
return nil
}

func (c *Cassandra) Lock() error {
func (c *Cassandra) Lock(ctx context.Context) error {
if !c.isLocked.CAS(false, true) {
return database.ErrLocked
}
return nil
}

func (c *Cassandra) Unlock() error {
func (c *Cassandra) Unlock(ctx context.Context) error {
if !c.isLocked.CAS(true, false) {
return database.ErrNotLocked
}
return nil
}

func (c *Cassandra) Run(migration io.Reader) error {
func (c *Cassandra) Run(ctx context.Context, migration io.Reader) error {
if c.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
Expand Down Expand Up @@ -243,7 +244,7 @@ func (c *Cassandra) Run(migration io.Reader) error {
return nil
}

func (c *Cassandra) SetVersion(version int, dirty bool) error {
func (c *Cassandra) SetVersion(ctx context.Context, version int, dirty bool) error {
// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
Expand Down Expand Up @@ -273,7 +274,7 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error {
}

// Return current keyspace version
func (c *Cassandra) Version() (version int, dirty bool, err error) {
func (c *Cassandra) Version(ctx context.Context) (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
err = c.session.Query(query).Scan(&version, &dirty)
switch {
Expand All @@ -291,7 +292,7 @@ func (c *Cassandra) Version() (version int, dirty bool, err error) {
}
}

func (c *Cassandra) Drop() error {
func (c *Cassandra) Drop(ctx context.Context) error {
// select all tables in current schema
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
iter := c.session.Query(query).Iter()
Expand All @@ -309,13 +310,13 @@ func (c *Cassandra) Drop() error {
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Cassandra type.
func (c *Cassandra) ensureVersionTable() (err error) {
if err = c.Lock(); err != nil {
func (c *Cassandra) ensureVersionTable(ctx context.Context) (err error) {
if err = c.Lock(ctx); err != nil {
return err
}

defer func() {
if e := c.Unlock(); e != nil {
if e := c.Unlock(ctx); e != nil {
if err == nil {
err = e
} else {
Expand All @@ -328,7 +329,7 @@ func (c *Cassandra) ensureVersionTable() (err error) {
if err != nil {
return err
}
if _, _, err = c.Version(); err != nil {
if _, _, err = c.Version(ctx); err != nil {
return err
}
return nil
Expand Down
12 changes: 7 additions & 5 deletions database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ func Test(t *testing.T) {

func test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
d, err := p.Open(ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
if err := d.Close(ctx); err != nil {
t.Error(err)
}
}()
Expand All @@ -97,23 +98,24 @@ func test(t *testing.T) {

func testMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
d, err := p.Open(ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
if err := d.Close(ctx); err != nil {
t.Error(err)
}
}()

m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d)
m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "testks", d)
if err != nil {
t.Fatal(err)
}
Expand Down
47 changes: 24 additions & 23 deletions database/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clickhouse

import (
"context"
"database/sql"
"fmt"
"io"
Expand Down Expand Up @@ -40,7 +41,7 @@ func init() {
database.Register("clickhouse", &ClickHouse{})
}

func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
func WithInstance(ctx context.Context, conn *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
Expand All @@ -54,7 +55,7 @@ func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
config: config,
}

if err := ch.init(); err != nil {
if err := ch.init(ctx); err != nil {
return nil, err
}

Expand All @@ -67,7 +68,7 @@ type ClickHouse struct {
isLocked atomic.Bool
}

func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
func (ch *ClickHouse) Open(ctx context.Context, dsn string) (database.Driver, error) {
purl, err := url.Parse(dsn)
if err != nil {
return nil, err
Expand Down Expand Up @@ -104,14 +105,14 @@ func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
},
}

if err := ch.init(); err != nil {
if err := ch.init(ctx); err != nil {
return nil, err
}

return ch, nil
}

func (ch *ClickHouse) init() error {
func (ch *ClickHouse) init(ctx context.Context) error {
if len(ch.config.DatabaseName) == 0 {
if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
return err
Expand All @@ -130,18 +131,18 @@ func (ch *ClickHouse) init() error {
ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
}

return ch.ensureVersionTable()
return ch.ensureVersionTable(ctx)
}

func (ch *ClickHouse) Run(r io.Reader) error {
func (ch *ClickHouse) Run(ctx context.Context, r io.Reader) error {
if ch.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
tq := strings.TrimSpace(string(m))
if tq == "" {
return true
}
if _, e := ch.conn.Exec(string(m)); e != nil {
if _, e := ch.conn.ExecContext(ctx, string(m)); e != nil {
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
return false
}
Expand All @@ -157,13 +158,13 @@ func (ch *ClickHouse) Run(r io.Reader) error {
return err
}

if _, err := ch.conn.Exec(string(migration)); err != nil {
if _, err := ch.conn.ExecContext(ctx, string(migration)); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migration}
}

return nil
}
func (ch *ClickHouse) Version() (int, bool, error) {
func (ch *ClickHouse) Version(ctx context.Context) (int, bool, error) {
var (
version int
dirty uint8
Expand All @@ -178,22 +179,22 @@ func (ch *ClickHouse) Version() (int, bool, error) {
return version, dirty == 1, nil
}

func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
func (ch *ClickHouse) SetVersion(ctx context.Context, version int, dirty bool) error {
var (
bool = func(v bool) uint8 {
if v {
return 1
}
return 0
}
tx, err = ch.conn.Begin()
tx, err = ch.conn.BeginTx(ctx, nil)
)
if err != nil {
return err
}

query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)"
if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil {
if _, err := tx.ExecContext(ctx, query, version, bool(dirty), time.Now().UnixNano()); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -203,13 +204,13 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the ClickHouse type.
func (ch *ClickHouse) ensureVersionTable() (err error) {
if err = ch.Lock(); err != nil {
func (ch *ClickHouse) ensureVersionTable(ctx context.Context) (err error) {
if err = ch.Lock(ctx); err != nil {
return err
}

defer func() {
if e := ch.Unlock(); e != nil {
if e := ch.Unlock(ctx); e != nil {
if err == nil {
err = e
} else {
Expand Down Expand Up @@ -252,15 +253,15 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
query = fmt.Sprintf(`%s ORDER BY sequence`, query)
}

if _, err := ch.conn.Exec(query); err != nil {
if _, err := ch.conn.ExecContext(ctx, query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}

func (ch *ClickHouse) Drop() (err error) {
func (ch *ClickHouse) Drop(ctx context.Context) (err error) {
query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName)
tables, err := ch.conn.Query(query)
tables, err := ch.conn.QueryContext(ctx, query)

if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
Expand All @@ -279,7 +280,7 @@ func (ch *ClickHouse) Drop() (err error) {

query = "DROP TABLE IF EXISTS " + quoteIdentifier(ch.config.DatabaseName) + "." + quoteIdentifier(table)

if _, err := ch.conn.Exec(query); err != nil {
if _, err := ch.conn.ExecContext(ctx, query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
Expand All @@ -290,21 +291,21 @@ func (ch *ClickHouse) Drop() (err error) {
return nil
}

func (ch *ClickHouse) Lock() error {
func (ch *ClickHouse) Lock(ctx context.Context) error {
if !ch.isLocked.CAS(false, true) {
return database.ErrLocked
}

return nil
}
func (ch *ClickHouse) Unlock() error {
func (ch *ClickHouse) Unlock(ctx context.Context) error {
if !ch.isLocked.CAS(true, false) {
return database.ErrNotLocked
}

return nil
}
func (ch *ClickHouse) Close() error { return ch.conn.Close() }
func (ch *ClickHouse) Close(ctx context.Context) error { return ch.conn.Close() }

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
Expand Down
Loading
Loading