Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
r0mant committed Apr 9, 2021
1 parent 8c12291 commit c35a718
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 63 deletions.
2 changes: 2 additions & 0 deletions api/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ const (

// UseOfClosedNetworkConnection is a special string some parts of
// go standard lib are using that is the only way to identify some errors
//
// TODO(r0mant): See if we can use net.ErrClosed and errors.Is() instead.
UseOfClosedNetworkConnection = "use of closed network connection"
)

Expand Down
78 changes: 27 additions & 51 deletions lib/srv/db/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,52 +50,27 @@ func TestAuditPostgres(t *testing.T) {
// Access denied should trigger an unsuccessful session start event.
_, err = testCtx.postgresClient(ctx, "alice", "notpostgres", "notpostgres")
require.Error(t, err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionStartFailureCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartFailureCode)
}
requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode)

// Connect should trigger successful session start event.
psql, err := testCtx.postgresClient(ctx, "alice", "postgres", "postgres")
require.NoError(t, err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionStartCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartCode)
}
requireEvent(t, testCtx, libevents.DatabaseSessionStartCode)

// Simple query should trigger the query event.
_, err = psql.Exec(ctx, "select 1").ReadAll()
require.NoError(t, err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionQueryCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionQueryCode)
}
requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1")

// Prepared statement execution should also trigger a query event.
result := psql.ExecParams(ctx, "select 1", nil, nil, nil, nil).Read()
result := psql.ExecParams(ctx, "select now()", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionQueryCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionQueryCode)
}
requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select now()")

// Closing connection should trigger session end event.
err = psql.Close(ctx)
require.NoError(t, err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionEndCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionEndCode)
}
requireEvent(t, testCtx, libevents.DatabaseSessionEndCode)
}

// TestAuditMySQL verifies proper audit events are emitted for MySQL
Expand All @@ -116,42 +91,43 @@ func TestAuditMySQL(t *testing.T) {
// Access denied should trigger an unsuccessful session start event.
_, err = testCtx.mysqlClient("alice", "notroot")
require.Error(t, err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionStartFailureCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartFailureCode)
}
requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode)

// Connect should trigger successful session start event.
mysql, err := testCtx.mysqlClient("alice", "root")
require.NoError(t, err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionStartCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartCode)
}
requireEvent(t, testCtx, libevents.DatabaseSessionStartCode)

// Simple query should trigger the query event.
_, err = mysql.Execute("select 1")
require.NoError(t, err)
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionQueryCode, event.GetCode())
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionQueryCode)
}
requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1")

// Closing connection should trigger session end event.
err = mysql.Close()
require.NoError(t, err)
requireEvent(t, testCtx, libevents.DatabaseSessionEndCode)
}

func requireEvent(t *testing.T, testCtx *testContext, code string) {
event := waitForEvent(t, testCtx, code)
require.Equal(t, code, event.GetCode())
}

func requireQueryEvent(t *testing.T, testCtx *testContext, code, query string) {
event := waitForEvent(t, testCtx, code)
require.Equal(t, code, event.GetCode())
require.Equal(t, query, event.(*events.DatabaseSessionQuery).DatabaseQuery)
}

func waitForEvent(t *testing.T, testCtx *testContext, code string) events.AuditEvent {
select {
case event := <-testCtx.emitter.eventsCh:
require.Equal(t, libevents.DatabaseSessionEndCode, event.GetCode())
return event
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionEndCode)
t.Fatalf("didn't receive %v event after 1 second", code)
}
return nil
}

// testEmitter pushes all received audit events into a channel.
Expand Down
3 changes: 0 additions & 3 deletions lib/srv/db/common/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ type audit struct {
log logrus.FieldLogger
}

// NewAuditFn defines a function that creates an audit logger.
type NewAuditFn func(AuditConfig) (Audit, error)

// NewAudit returns a new instance of the audit events emitter.
func NewAudit(config AuditConfig) (Audit, error) {
if err := config.Check(); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/common/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Session struct {
// Log is the logger with session specific fields.
Log logrus.FieldLogger
// Statements is the session's prepared statements cache.
Statements StatementsCache
Statements *StatementsCache
}

// String returns string representation of the session parameters.
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/db/common/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ type Portal struct {
}

// NewStatementsCache returns a new instance of prepared statements cache.
func NewStatementsCache() StatementsCache {
return StatementsCache{cache: make(map[string]Statement)}
func NewStatementsCache() *StatementsCache {
return &StatementsCache{cache: make(map[string]Statement)}
}

// Save adds the provided prepared statement information to the cache.
Expand Down
15 changes: 10 additions & 5 deletions lib/srv/db/postgres/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ func (e *Engine) receiveFromClient(client *pgproto3.Backend, server *pgproto3.Fr
// https://www.postgresql.org/docs/10/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
sessionCtx.Statements.Save(msg.Name, msg.Query)
case *pgproto3.Bind:
// Bind message readies an existing prepared statement (created when
// Parse message is received), for execution into what Postgres
// Bind message readies existing prepared statement (created when
// Parse message is received) for execution into what Postgres
// calls a "destination portal", optionally binding it with
// parameters (for parameterized queries).
err := sessionCtx.Statements.Bind(
Expand Down Expand Up @@ -406,15 +406,20 @@ func getBindParameters(msg *pgproto3.Bind) (parameters []string) {
// According to Bind message documentation, if there are no parameter
// format codes, it may mean that either there are no parameters, or
// that all parameters use default text format.
if len(msg.ParameterFormatCodes) == 0 || msg.ParameterFormatCodes[i] == parameterFormatCodeText {
if len(msg.ParameterFormatCodes) == 0 {
parameters = append(parameters, string(p))
continue
}
switch msg.ParameterFormatCodes[i] {
case parameterFormatCodeText:
// Text parameters can just be converted to their string
// representation.
parameters = append(parameters, string(p))
} else if msg.ParameterFormatCodes[i] == parameterFormatCodeBinary {
case parameterFormatCodeBinary:
// For binary parameters, just put a placeholder to avoid
// spamming the audit log with unreadable info.
parameters = append(parameters, "<binary>")
} else {
default:
// Should never happen but...
logrus.Warnf("Unknown Postgres parameter format code: %#v.", msg)
parameters = append(parameters, "<unknown>")
Expand Down
5 changes: 4 additions & 1 deletion lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type Config struct {
// StreamEmitter is a non-blocking audit events emitter.
StreamEmitter events.StreamEmitter
// NewAudit allows to override audit logger in tests.
NewAudit common.NewAuditFn
NewAudit NewAuditFn
// TLSConfig is the *tls.Config for this server.
TLSConfig *tls.Config
// Authorizer is used to authorize requests coming from proxy.
Expand All @@ -75,6 +75,9 @@ type Config struct {
OnHeartbeat func(error)
}

// NewAuditFn defines a function that creates an audit logger.
type NewAuditFn func(common.AuditConfig) (common.Audit, error)

// CheckAndSetDefaults makes sure the configuration has the minimum required
// to function.
func (c *Config) CheckAndSetDefaults(ctx context.Context) error {
Expand Down

0 comments on commit c35a718

Please sign in to comment.