From f7b1a0e6f1f2ab0a61697062d841d9948161e3bc Mon Sep 17 00:00:00 2001 From: Andrew Kroh Date: Fri, 1 Sep 2017 13:32:00 -0400 Subject: [PATCH 1/2] Add support for listening to audit multicast group This adds `NewMulticastAuditClient` that creates a client that listens to the audit multicast group that was added in kernel 3.16. Closes #9 --- CHANGELOG.md | 2 ++ audit.go | 22 ++++++++++++- audit_test.go | 35 ++++++++++++++++++++ cmd/audit/audit.go | 82 +++++++++++++++++++++++++++------------------- netlink.go | 6 ++-- netlink_test.go | 4 +-- 6 files changed, 112 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1afa7a..412aaac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added +- Add support for listening for audit messages using a multicast group. #9 + ### Changed ### Deprecated diff --git a/audit.go b/audit.go index db42732..cb3dfbc 100644 --- a/audit.go +++ b/audit.go @@ -43,6 +43,12 @@ const ( AuditSet ) +// Netlink groups. +const ( + NetlinkGroupNone = iota // Group 0 not used + NetlinkGroupReadLog // "best effort" read only socket +) + // WaitMode is a flag to control the behavior of methods that abstract // asynchronous communication for the caller. type WaitMode uint8 @@ -72,13 +78,27 @@ type AuditClient struct { Netlink NetlinkSendReceiver } +// NewMulticastAuditClient creates a new AuditClient that binds to the multicast +// socket subscribes to the audit group. The process should have the +// CAP_AUDIT_READ capability to use this. This audit client should not be used +// for command and control. The resp parameter is optional. If provided resp +// will receive a copy of all data read from the netlink socket. This is useful +// for debugging purposes. +func NewMulticastAuditClient(resp io.Writer) (*AuditClient, error) { + return newAuditClient(NetlinkGroupReadLog, resp) +} + // NewAuditClient creates a new AuditClient. The resp parameter is optional. If // provided resp will receive a copy of all data read from the netlink socket. // This is useful for debugging purposes. func NewAuditClient(resp io.Writer) (*AuditClient, error) { + return newAuditClient(NetlinkGroupNone, resp) +} + +func newAuditClient(netlinkGroups uint32, resp io.Writer) (*AuditClient, error) { buf := make([]byte, syscall.NLMSG_HDRLEN+AuditMessageMaxLength) - netlink, err := NewNetlinkClient(syscall.NETLINK_AUDIT, buf, resp) + netlink, err := NewNetlinkClient(syscall.NETLINK_AUDIT, netlinkGroups, buf, resp) if err != nil { return nil, err } diff --git a/audit_test.go b/audit_test.go index bb46fed..36b679e 100644 --- a/audit_test.go +++ b/audit_test.go @@ -387,6 +387,41 @@ func TestAuditClientSetBacklogLimit(t *testing.T) { assert.EqualValues(t, limit, status.BacklogLimit) } +func TestMulticastAuditClient(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("must be root to bind to netlink audit socket") + } + + var dumper io.WriteCloser + if *hexdump { + dumper = hex.Dumper(os.Stdout) + defer dumper.Close() + } + + // Start the testing. + client, err := NewMulticastAuditClient(dumper) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + // Receive (likely no messages will be received). + var msgCount int + for i := 0; i < 5; i++ { + msg, err := client.Receive(true) + if err == syscall.EAGAIN { + time.Sleep(500 * time.Millisecond) + continue + } else if err != nil { + t.Fatal(err) + } else { + t.Logf("Received: type=%v, msg=%v", msg.Type, string(msg.Data)) + msgCount++ + } + } + t.Logf("received %d messages", msgCount) +} + func TestAuditClientReceive(t *testing.T) { if os.Geteuid() != 0 { t.Skip("must be root to set audit port id") diff --git a/cmd/audit/audit.go b/cmd/audit/audit.go index 0008496..9806313 100644 --- a/cmd/audit/audit.go +++ b/cmd/audit/audit.go @@ -34,11 +34,12 @@ import ( ) var ( - fs = flag.NewFlagSet("audit", flag.ExitOnError) - debug = fs.Bool("d", false, "enable debug output to stderr") - diag = fs.String("diag", "", "dump raw information from kernel to file") - rate = fs.Uint("rate", 0, "rate limit in kernel (default 0, no rate limit)") - backlog = fs.Uint("backlog", 8192, "backlog limit") + fs = flag.NewFlagSet("audit", flag.ExitOnError) + debug = fs.Bool("d", false, "enable debug output to stderr") + diag = fs.String("diag", "", "dump raw information from kernel to file") + rate = fs.Uint("rate", 0, "rate limit in kernel (default 0, no rate limit)") + backlog = fs.Uint("backlog", 8192, "backlog limit") + receiveOnly = fs.Bool("ro", false, "receive only using multicast, requires kernel 3.16+") ) func enableLogger() { @@ -81,45 +82,60 @@ func read() error { } log.Debugln("starting netlink client") - client, err := libaudit.NewAuditClient(diagWriter) - if err != nil { - return err - } - status, err := client.GetStatus() - if err != nil { - return errors.Wrap(err, "failed to get audit status") - } - log.Infof("received audit status=%+v", status) + var err error + var client *libaudit.AuditClient + if *receiveOnly { + client, err = libaudit.NewMulticastAuditClient(diagWriter) + if err != nil { + return errors.Wrap(err, "failed to create receive-only audit client") + } + } else { + client, err = libaudit.NewAuditClient(diagWriter) + if err != nil { + return errors.Wrap(err, "failed to create audit client") + } - if status.Enabled == 0 { - log.Debugln("enabling auditing in the kernel") - if err = client.SetEnabled(true, libaudit.WaitForReply); err != nil { - return errors.Wrap(err, "failed to set enabled=true") + status, err := client.GetStatus() + if err != nil { + return errors.Wrap(err, "failed to get audit status") } - } + log.Infof("received audit status=%+v", status) - if status.RateLimit != uint32(*rate) { - log.Debugf("setting rate limit in kernel to %v", *rate) - if err = client.SetRateLimit(uint32(*rate), libaudit.NoWait); err != nil { - return errors.Wrap(err, "failed to set rate limit to unlimited") + if status.Enabled == 0 { + log.Debugln("enabling auditing in the kernel") + if err = client.SetEnabled(true, libaudit.WaitForReply); err != nil { + return errors.Wrap(err, "failed to set enabled=true") + } } - } - if status.BacklogLimit != uint32(*backlog) { - log.Debugf("setting backlog limit in kernel to %v", *backlog) - if err = client.SetBacklogLimit(uint32(*backlog), libaudit.NoWait); err != nil { - return errors.Wrap(err, "failed to set backlog limit") + if status.RateLimit != uint32(*rate) { + log.Debugf("setting rate limit in kernel to %v", *rate) + if err = client.SetRateLimit(uint32(*rate), libaudit.NoWait); err != nil { + return errors.Wrap(err, "failed to set rate limit to unlimited") + } + } + + if status.BacklogLimit != uint32(*backlog) { + log.Debugf("setting backlog limit in kernel to %v", *backlog) + if err = client.SetBacklogLimit(uint32(*backlog), libaudit.NoWait); err != nil { + return errors.Wrap(err, "failed to set backlog limit") + } } - } - log.Debugf("sending message to kernel registering our PID (%v) as the audit daemon", os.Getpid()) - if err = client.SetPID(libaudit.NoWait); err != nil { - return errors.Wrap(err, "failed to set audit PID") + log.Debugf("sending message to kernel registering our PID (%v) as the audit daemon", os.Getpid()) + if err = client.SetPID(libaudit.NoWait); err != nil { + return errors.Wrap(err, "failed to set audit PID") + } } + defer client.Close() + + return receive(client) +} +func receive(r *libaudit.AuditClient) error { for { - rawEvent, err := client.Receive(false) + rawEvent, err := r.Receive(false) if err != nil { return errors.Wrap(err, "receive failed") } diff --git a/netlink.go b/netlink.go index aa111d8..1300245 100644 --- a/netlink.go +++ b/netlink.go @@ -73,16 +73,16 @@ type NetlinkClient struct { // (this is useful for debugging). // // The returned NetlinkClient must be closed with Close() when finished. -func NewNetlinkClient(proto int, readBuf []byte, resp io.Writer) (*NetlinkClient, error) { +func NewNetlinkClient(proto int, groups uint32, readBuf []byte, resp io.Writer) (*NetlinkClient, error) { s, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, proto) if err != nil { return nil, err } - src := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK} + src := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: groups} if err = syscall.Bind(s, src); err != nil { syscall.Close(s) - return nil, err + return nil, errors.Wrap(err, "bind failed") } pid, err := getPortID(s) diff --git a/netlink_test.go b/netlink_test.go index b8ce2b0..8c654dc 100644 --- a/netlink_test.go +++ b/netlink_test.go @@ -28,7 +28,7 @@ import ( var _ NetlinkSendReceiver = &NetlinkClient{} func TestNewNetlinkClient(t *testing.T) { - c, err := NewNetlinkClient(syscall.NETLINK_AUDIT, nil, nil) + c, err := NewNetlinkClient(syscall.NETLINK_AUDIT, 0, nil, nil) if err != nil { t.Fatal(err) } @@ -39,7 +39,7 @@ func TestNewNetlinkClient(t *testing.T) { // First PID assigned by the kernel will be our actual PID. assert.EqualValues(t, os.Getpid(), c.pid) - c2, err := NewNetlinkClient(syscall.NETLINK_AUDIT, nil, nil) + c2, err := NewNetlinkClient(syscall.NETLINK_AUDIT, 0, nil, nil) if err != nil { t.Fatal(err) } From 8f4a68d5618e2b76fb760fbed017255a55a74f4c Mon Sep 17 00:00:00 2001 From: Andrew Kroh Date: Thu, 7 Sep 2017 16:05:15 -0400 Subject: [PATCH 2/2] Move the "defer Close" closer to the client initialization --- cmd/audit/audit.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/audit/audit.go b/cmd/audit/audit.go index 9806313..4bb8233 100644 --- a/cmd/audit/audit.go +++ b/cmd/audit/audit.go @@ -90,11 +90,13 @@ func read() error { if err != nil { return errors.Wrap(err, "failed to create receive-only audit client") } + defer client.Close() } else { client, err = libaudit.NewAuditClient(diagWriter) if err != nil { return errors.Wrap(err, "failed to create audit client") } + defer client.Close() status, err := client.GetStatus() if err != nil { @@ -128,7 +130,6 @@ func read() error { return errors.Wrap(err, "failed to set audit PID") } } - defer client.Close() return receive(client) }