diff --git a/api/constants/constants.go b/api/constants/constants.go index 97551f2612cc7..3c418376b1627 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -94,6 +94,8 @@ const ( // UseOfClosedNetworkConnection is a special string some parts of // go standard lib are using that is the only way to identify some errors + // + // TODO(r0mant): See if we can use net.ErrClosed and errors.Is() instead. UseOfClosedNetworkConnection = "use of closed network connection" ) diff --git a/lib/srv/db/audit_test.go b/lib/srv/db/audit_test.go index 960fe2a790dd5..7a76ca91c66f7 100644 --- a/lib/srv/db/audit_test.go +++ b/lib/srv/db/audit_test.go @@ -50,52 +50,27 @@ func TestAuditPostgres(t *testing.T) { // Access denied should trigger an unsuccessful session start event. _, err = testCtx.postgresClient(ctx, "alice", "notpostgres", "notpostgres") require.Error(t, err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionStartFailureCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartFailureCode) - } + requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode) // Connect should trigger successful session start event. psql, err := testCtx.postgresClient(ctx, "alice", "postgres", "postgres") require.NoError(t, err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionStartCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartCode) - } + requireEvent(t, testCtx, libevents.DatabaseSessionStartCode) // Simple query should trigger the query event. _, err = psql.Exec(ctx, "select 1").ReadAll() require.NoError(t, err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionQueryCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionQueryCode) - } + requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1") // Prepared statement execution should also trigger a query event. - result := psql.ExecParams(ctx, "select 1", nil, nil, nil, nil).Read() + result := psql.ExecParams(ctx, "select now()", nil, nil, nil, nil).Read() require.NoError(t, result.Err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionQueryCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionQueryCode) - } + requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select now()") // Closing connection should trigger session end event. err = psql.Close(ctx) require.NoError(t, err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionEndCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionEndCode) - } + requireEvent(t, testCtx, libevents.DatabaseSessionEndCode) } // TestAuditMySQL verifies proper audit events are emitted for MySQL @@ -116,42 +91,43 @@ func TestAuditMySQL(t *testing.T) { // Access denied should trigger an unsuccessful session start event. _, err = testCtx.mysqlClient("alice", "notroot") require.Error(t, err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionStartFailureCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartFailureCode) - } + requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode) // Connect should trigger successful session start event. mysql, err := testCtx.mysqlClient("alice", "root") require.NoError(t, err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionStartCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionStartCode) - } + requireEvent(t, testCtx, libevents.DatabaseSessionStartCode) // Simple query should trigger the query event. _, err = mysql.Execute("select 1") require.NoError(t, err) - select { - case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionQueryCode, event.GetCode()) - case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionQueryCode) - } + requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1") // Closing connection should trigger session end event. err = mysql.Close() require.NoError(t, err) + requireEvent(t, testCtx, libevents.DatabaseSessionEndCode) +} + +func requireEvent(t *testing.T, testCtx *testContext, code string) { + event := waitForEvent(t, testCtx, code) + require.Equal(t, code, event.GetCode()) +} + +func requireQueryEvent(t *testing.T, testCtx *testContext, code, query string) { + event := waitForEvent(t, testCtx, code) + require.Equal(t, code, event.GetCode()) + require.Equal(t, query, event.(*events.DatabaseSessionQuery).DatabaseQuery) +} + +func waitForEvent(t *testing.T, testCtx *testContext, code string) events.AuditEvent { select { case event := <-testCtx.emitter.eventsCh: - require.Equal(t, libevents.DatabaseSessionEndCode, event.GetCode()) + return event case <-time.After(time.Second): - t.Fatalf("didn't receive %v event after 1 second", libevents.DatabaseSessionEndCode) + t.Fatalf("didn't receive %v event after 1 second", code) } + return nil } // testEmitter pushes all received audit events into a channel. diff --git a/lib/srv/db/common/audit.go b/lib/srv/db/common/audit.go index 1f3055d188c6d..0c131d343244d 100644 --- a/lib/srv/db/common/audit.go +++ b/lib/srv/db/common/audit.go @@ -59,9 +59,6 @@ type audit struct { log logrus.FieldLogger } -// NewAuditFn defines a function that creates an audit logger. -type NewAuditFn func(AuditConfig) (Audit, error) - // NewAudit returns a new instance of the audit events emitter. func NewAudit(config AuditConfig) (Audit, error) { if err := config.Check(); err != nil { diff --git a/lib/srv/db/common/session.go b/lib/srv/db/common/session.go index 57bbd5b492fdc..ea2c691c104b5 100644 --- a/lib/srv/db/common/session.go +++ b/lib/srv/db/common/session.go @@ -47,7 +47,7 @@ type Session struct { // Log is the logger with session specific fields. Log logrus.FieldLogger // Statements is the session's prepared statements cache. - Statements StatementsCache + Statements *StatementsCache } // String returns string representation of the session parameters. diff --git a/lib/srv/db/common/statements.go b/lib/srv/db/common/statements.go index 8dbd0863944b3..e515934d3c074 100644 --- a/lib/srv/db/common/statements.go +++ b/lib/srv/db/common/statements.go @@ -66,8 +66,8 @@ type Portal struct { } // NewStatementsCache returns a new instance of prepared statements cache. -func NewStatementsCache() StatementsCache { - return StatementsCache{cache: make(map[string]Statement)} +func NewStatementsCache() *StatementsCache { + return &StatementsCache{cache: make(map[string]Statement)} } // Save adds the provided prepared statement information to the cache. diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index d113db66fc272..cb32c66b725dd 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -273,8 +273,8 @@ func (e *Engine) receiveFromClient(client *pgproto3.Backend, server *pgproto3.Fr // https://www.postgresql.org/docs/10/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY sessionCtx.Statements.Save(msg.Name, msg.Query) case *pgproto3.Bind: - // Bind message readies an existing prepared statement (created when - // Parse message is received), for execution into what Postgres + // Bind message readies existing prepared statement (created when + // Parse message is received) for execution into what Postgres // calls a "destination portal", optionally binding it with // parameters (for parameterized queries). err := sessionCtx.Statements.Bind( @@ -406,15 +406,20 @@ func getBindParameters(msg *pgproto3.Bind) (parameters []string) { // According to Bind message documentation, if there are no parameter // format codes, it may mean that either there are no parameters, or // that all parameters use default text format. - if len(msg.ParameterFormatCodes) == 0 || msg.ParameterFormatCodes[i] == parameterFormatCodeText { + if len(msg.ParameterFormatCodes) == 0 { + parameters = append(parameters, string(p)) + continue + } + switch msg.ParameterFormatCodes[i] { + case parameterFormatCodeText: // Text parameters can just be converted to their string // representation. parameters = append(parameters, string(p)) - } else if msg.ParameterFormatCodes[i] == parameterFormatCodeBinary { + case parameterFormatCodeBinary: // For binary parameters, just put a placeholder to avoid // spamming the audit log with unreadable info. parameters = append(parameters, "") - } else { + default: // Should never happen but... logrus.Warnf("Unknown Postgres parameter format code: %#v.", msg) parameters = append(parameters, "") diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 5ee897877df70..11d8a8adc4b1b 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -58,7 +58,7 @@ type Config struct { // StreamEmitter is a non-blocking audit events emitter. StreamEmitter events.StreamEmitter // NewAudit allows to override audit logger in tests. - NewAudit common.NewAuditFn + NewAudit NewAuditFn // TLSConfig is the *tls.Config for this server. TLSConfig *tls.Config // Authorizer is used to authorize requests coming from proxy. @@ -75,6 +75,9 @@ type Config struct { OnHeartbeat func(error) } +// NewAuditFn defines a function that creates an audit logger. +type NewAuditFn func(common.AuditConfig) (common.Audit, error) + // CheckAndSetDefaults makes sure the configuration has the minimum required // to function. func (c *Config) CheckAndSetDefaults(ctx context.Context) error {