Skip to content

Commit

Permalink
workload: fix Query for prepare method
Browse files Browse the repository at this point in the history
There was a bug because the connection would be released back to the
pool before the rows were read.

The only reason the connection was being borrowed explicitly was to
prepare statements, so now the statements have been changed to be
prepared as part of the BeforeAcquire callback instead.

Release justification: test only change
Release note: None
  • Loading branch information
rafiss committed Sep 1, 2021
1 parent 5e9110d commit 60dd572
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 76 deletions.
1 change: 1 addition & 0 deletions pkg/workload/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ go_library(
"//pkg/sql/types",
"//pkg/util/bufalloc",
"//pkg/util/encoding/csv",
"//pkg/util/log",
"//pkg/util/timeutil",
"//pkg/workload/histogram",
"@com_github_cockroachdb_errors//:errors",
Expand Down
41 changes: 37 additions & 4 deletions pkg/workload/pgx_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"context"
"sync/atomic"

"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"golang.org/x/sync/errgroup"
)
Expand All @@ -23,6 +25,9 @@ type MultiConnPool struct {
Pools []*pgxpool.Pool
// Atomic counter used by Get().
counter uint32
// preparedStatements is a map from name to SQL. The statements in the map
// are prepared whenever a new connection is acquired from the pool.
preparedStatements map[string]string
}

// MultiConnPoolCfg encapsulates the knobs passed to NewMultiConnPool.
Expand All @@ -39,6 +44,18 @@ type MultiConnPoolCfg struct {
MaxConnsPerPool int
}

// pgxLogger implements the pgx.Logger interface.
type pgxLogger struct{}

var _ pgx.Logger = pgxLogger{}

// Log implements the pgx.Logger interface.
func (p pgxLogger) Log(
ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{},
) {
log.Infof(ctx, "pgx logger [%s]: %s logParams=%v", level.String(), msg, data)
}

// NewMultiConnPool creates a new MultiConnPool.
//
// Each URL gets one or more pools, and each pool has at most MaxConnsPerPool
Expand All @@ -49,7 +66,9 @@ type MultiConnPoolCfg struct {
func NewMultiConnPool(
ctx context.Context, cfg MultiConnPoolCfg, urls ...string,
) (*MultiConnPool, error) {
m := &MultiConnPool{}
m := &MultiConnPool{
preparedStatements: map[string]string{},
}
connsPerURL := distribute(cfg.MaxTotalConnections, len(urls))
maxConnsPerPool := cfg.MaxConnsPerPool
if maxConnsPerPool == 0 {
Expand All @@ -61,13 +80,27 @@ func NewMultiConnPool(
connsPerPool := distributeMax(connsPerURL[i], maxConnsPerPool)
for _, numConns := range connsPerPool {
connCfg, err := pgxpool.ParseConfig(urls[i])
// Disable the automatic prepared statement cache. We've seen a lot of
// churn in this cache since workloads create many of different queries.
connCfg.ConnConfig.BuildStatementCache = nil
if err != nil {
return nil, err
}
// Disable the automatic prepared statement cache. We've seen a lot of
// churn in this cache since workloads create many of different queries.
connCfg.ConnConfig.BuildStatementCache = nil
connCfg.ConnConfig.LogLevel = pgx.LogLevelWarn
connCfg.ConnConfig.Logger = pgxLogger{}
connCfg.MaxConns = int32(numConns)
connCfg.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
for name, sql := range m.preparedStatements {
// Note that calling `Prepare` with a name that has already been
// prepared is idempotent and short-circuits before doing any
// communication to the server.
if _, err := conn.Prepare(ctx, name, sql); err != nil {
log.Warningf(ctx, "error preparing statement. name=%s sql=%s %v", name, sql, err)
return false
}
}
return true
}
p, err := pgxpool.ConnectConfig(ctx, connCfg)
if err != nil {
return nil, err
Expand Down
76 changes: 4 additions & 72 deletions pkg/workload/sql_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/cockroachdb/errors"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)

// SQLRunner is a helper for issuing SQL statements; it supports multiple
Expand Down Expand Up @@ -120,6 +119,7 @@ func (sr *SQLRunner) Init(
for i, s := range sr.stmts {
stmtName := fmt.Sprintf("%s-%d", name, i+1)
s.preparedName = stmtName
mcp.preparedStatements[stmtName] = s.sql
}
}

Expand Down Expand Up @@ -156,19 +156,7 @@ func (h StmtHandle) Exec(ctx context.Context, args ...interface{}) (pgconn.Comma
p := h.s.sr.mcp.Get()
switch h.s.sr.method {
case prepare:
// Note that calling `Prepare` with a name that has already been prepared
// is idempotent and short-circuits before doing any communication to the
// server.
var commandTag pgconn.CommandTag
err := p.AcquireFunc(ctx, func(conn *pgxpool.Conn) error {
if _, err := conn.Conn().Prepare(ctx, h.s.preparedName, h.s.sql); err != nil {
return err
}
var connErr error
commandTag, connErr = conn.Conn().Exec(ctx, h.s.preparedName, args...)
return connErr
})
return commandTag, err
return p.Exec(ctx, h.s.preparedName, args...)

case noprepare:
return p.Exec(ctx, h.s.sql, args...)
Expand All @@ -191,12 +179,6 @@ func (h StmtHandle) ExecTx(
h.check()
switch h.s.sr.method {
case prepare:
// Note that calling `Prepare` with a name that has already been prepared
// is idempotent and short-circuits before doing any communication to the
// server.
if _, err := tx.Prepare(ctx, h.s.preparedName, h.s.sql); err != nil {
return nil, err
}
return tx.Exec(ctx, h.s.preparedName, args...)

case noprepare:
Expand All @@ -219,19 +201,7 @@ func (h StmtHandle) Query(ctx context.Context, args ...interface{}) (pgx.Rows, e
p := h.s.sr.mcp.Get()
switch h.s.sr.method {
case prepare:
// Note that calling `Prepare` with a name that has already been prepared
// is idempotent and short-circuits before doing any communication to the
// server.
var rows pgx.Rows
err := p.AcquireFunc(ctx, func(conn *pgxpool.Conn) error {
if _, err := conn.Conn().Prepare(ctx, h.s.preparedName, h.s.sql); err != nil {
return err
}
var connErr error
rows, connErr = conn.Conn().Query(ctx, h.s.preparedName, args...)
return connErr
})
return rows, err
return p.Query(ctx, h.s.preparedName, args...)

case noprepare:
return p.Query(ctx, h.s.sql, args...)
Expand All @@ -252,12 +222,6 @@ func (h StmtHandle) QueryTx(ctx context.Context, tx pgx.Tx, args ...interface{})
h.check()
switch h.s.sr.method {
case prepare:
// Note that calling `Prepare` with a name that has already been prepared
// is idempotent and short-circuits before doing any communication to the
// server.
if _, err := tx.Prepare(ctx, h.s.preparedName, h.s.sql); err != nil {
return nil, err
}
return tx.Query(ctx, h.s.preparedName, args...)

case noprepare:
Expand All @@ -280,22 +244,7 @@ func (h StmtHandle) QueryRow(ctx context.Context, args ...interface{}) pgx.Row {
p := h.s.sr.mcp.Get()
switch h.s.sr.method {
case prepare:
// Note that calling `Prepare` with a name that has already been prepared
// is idempotent and short-circuits before doing any communication to the
// server.
var row pgx.Row
err := p.AcquireFunc(ctx, func(conn *pgxpool.Conn) error {
if _, err := conn.Conn().Prepare(ctx, h.s.preparedName, h.s.sql); err != nil {
return err
}
row = conn.Conn().QueryRow(ctx, h.s.preparedName, args...)
return nil
})
if err != nil {
r := errRow{retErr: err}
return &r
}
return row
return p.QueryRow(ctx, h.s.preparedName, args...)

case noprepare:
return p.QueryRow(ctx, h.s.sql, args...)
Expand All @@ -317,13 +266,6 @@ func (h StmtHandle) QueryRowTx(ctx context.Context, tx pgx.Tx, args ...interface
h.check()
switch h.s.sr.method {
case prepare:
// Note that calling `Prepare` with a name that has already been prepared
// is idempotent and short-circuits before doing any communication to the
// server.
if _, err := tx.Prepare(ctx, h.s.preparedName, h.s.sql); err != nil {
r := errRow{retErr: err}
return &r
}
return tx.QueryRow(ctx, h.s.preparedName, args...)

case noprepare:
Expand All @@ -338,15 +280,5 @@ func (h StmtHandle) QueryRowTx(ctx context.Context, tx pgx.Tx, args ...interface
}
}

// errRow implements the pgx.Row interface. It's used only in the `prepare`
// mode.
type errRow struct {
retErr error
}

var _ pgx.Row = &errRow{}

func (r *errRow) Scan(_ ...interface{}) error { return r.retErr }

// Appease the linter.
var _ = StmtHandle.QueryRow

0 comments on commit 60dd572

Please sign in to comment.