Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for listening to audit multicast group #12

Merged
merged 2 commits into from
Sep 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).

### Added

- Add support for listening for audit messages using a multicast group. #9

### Changed

### Deprecated
Expand Down
22 changes: 21 additions & 1 deletion audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ const (
AuditSet
)

// Netlink groups.
const (
NetlinkGroupNone = iota // Group 0 not used
NetlinkGroupReadLog // "best effort" read only socket
)

// WaitMode is a flag to control the behavior of methods that abstract
// asynchronous communication for the caller.
type WaitMode uint8
Expand Down Expand Up @@ -72,13 +78,27 @@ type AuditClient struct {
Netlink NetlinkSendReceiver
}

// NewMulticastAuditClient creates a new AuditClient that binds to the multicast
// socket subscribes to the audit group. The process should have the
// CAP_AUDIT_READ capability to use this. This audit client should not be used
// for command and control. The resp parameter is optional. If provided resp
// will receive a copy of all data read from the netlink socket. This is useful
// for debugging purposes.
func NewMulticastAuditClient(resp io.Writer) (*AuditClient, error) {
return newAuditClient(NetlinkGroupReadLog, resp)
}

// NewAuditClient creates a new AuditClient. The resp parameter is optional. If
// provided resp will receive a copy of all data read from the netlink socket.
// This is useful for debugging purposes.
func NewAuditClient(resp io.Writer) (*AuditClient, error) {
return newAuditClient(NetlinkGroupNone, resp)
}

func newAuditClient(netlinkGroups uint32, resp io.Writer) (*AuditClient, error) {
buf := make([]byte, syscall.NLMSG_HDRLEN+AuditMessageMaxLength)

netlink, err := NewNetlinkClient(syscall.NETLINK_AUDIT, buf, resp)
netlink, err := NewNetlinkClient(syscall.NETLINK_AUDIT, netlinkGroups, buf, resp)
if err != nil {
return nil, err
}
Expand Down
35 changes: 35 additions & 0 deletions audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,41 @@ func TestAuditClientSetBacklogLimit(t *testing.T) {
assert.EqualValues(t, limit, status.BacklogLimit)
}

func TestMulticastAuditClient(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("must be root to bind to netlink audit socket")
}

var dumper io.WriteCloser
if *hexdump {
dumper = hex.Dumper(os.Stdout)
defer dumper.Close()
}

// Start the testing.
client, err := NewMulticastAuditClient(dumper)
if err != nil {
t.Fatal(err)
}
defer client.Close()

// Receive (likely no messages will be received).
var msgCount int
for i := 0; i < 5; i++ {
msg, err := client.Receive(true)
if err == syscall.EAGAIN {
time.Sleep(500 * time.Millisecond)
continue
} else if err != nil {
t.Fatal(err)
} else {
t.Logf("Received: type=%v, msg=%v", msg.Type, string(msg.Data))
msgCount++
}
}
t.Logf("received %d messages", msgCount)
}

func TestAuditClientReceive(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("must be root to set audit port id")
Expand Down
83 changes: 50 additions & 33 deletions cmd/audit/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ import (
)

var (
fs = flag.NewFlagSet("audit", flag.ExitOnError)
debug = fs.Bool("d", false, "enable debug output to stderr")
diag = fs.String("diag", "", "dump raw information from kernel to file")
rate = fs.Uint("rate", 0, "rate limit in kernel (default 0, no rate limit)")
backlog = fs.Uint("backlog", 8192, "backlog limit")
fs = flag.NewFlagSet("audit", flag.ExitOnError)
debug = fs.Bool("d", false, "enable debug output to stderr")
diag = fs.String("diag", "", "dump raw information from kernel to file")
rate = fs.Uint("rate", 0, "rate limit in kernel (default 0, no rate limit)")
backlog = fs.Uint("backlog", 8192, "backlog limit")
receiveOnly = fs.Bool("ro", false, "receive only using multicast, requires kernel 3.16+")
)

func enableLogger() {
Expand Down Expand Up @@ -81,45 +82,61 @@ func read() error {
}

log.Debugln("starting netlink client")
client, err := libaudit.NewAuditClient(diagWriter)
if err != nil {
return err
}

status, err := client.GetStatus()
if err != nil {
return errors.Wrap(err, "failed to get audit status")
}
log.Infof("received audit status=%+v", status)
var err error
var client *libaudit.AuditClient
if *receiveOnly {
client, err = libaudit.NewMulticastAuditClient(diagWriter)
if err != nil {
return errors.Wrap(err, "failed to create receive-only audit client")
}
defer client.Close()
} else {
client, err = libaudit.NewAuditClient(diagWriter)
if err != nil {
return errors.Wrap(err, "failed to create audit client")
}
defer client.Close()

if status.Enabled == 0 {
log.Debugln("enabling auditing in the kernel")
if err = client.SetEnabled(true, libaudit.WaitForReply); err != nil {
return errors.Wrap(err, "failed to set enabled=true")
status, err := client.GetStatus()
if err != nil {
return errors.Wrap(err, "failed to get audit status")
}
}
log.Infof("received audit status=%+v", status)

if status.RateLimit != uint32(*rate) {
log.Debugf("setting rate limit in kernel to %v", *rate)
if err = client.SetRateLimit(uint32(*rate), libaudit.NoWait); err != nil {
return errors.Wrap(err, "failed to set rate limit to unlimited")
if status.Enabled == 0 {
log.Debugln("enabling auditing in the kernel")
if err = client.SetEnabled(true, libaudit.WaitForReply); err != nil {
return errors.Wrap(err, "failed to set enabled=true")
}
}
}

if status.BacklogLimit != uint32(*backlog) {
log.Debugf("setting backlog limit in kernel to %v", *backlog)
if err = client.SetBacklogLimit(uint32(*backlog), libaudit.NoWait); err != nil {
return errors.Wrap(err, "failed to set backlog limit")
if status.RateLimit != uint32(*rate) {
log.Debugf("setting rate limit in kernel to %v", *rate)
if err = client.SetRateLimit(uint32(*rate), libaudit.NoWait); err != nil {
return errors.Wrap(err, "failed to set rate limit to unlimited")
}
}

if status.BacklogLimit != uint32(*backlog) {
log.Debugf("setting backlog limit in kernel to %v", *backlog)
if err = client.SetBacklogLimit(uint32(*backlog), libaudit.NoWait); err != nil {
return errors.Wrap(err, "failed to set backlog limit")
}
}
}

log.Debugf("sending message to kernel registering our PID (%v) as the audit daemon", os.Getpid())
if err = client.SetPID(libaudit.NoWait); err != nil {
return errors.Wrap(err, "failed to set audit PID")
log.Debugf("sending message to kernel registering our PID (%v) as the audit daemon", os.Getpid())
if err = client.SetPID(libaudit.NoWait); err != nil {
return errors.Wrap(err, "failed to set audit PID")
}
}

return receive(client)
}

func receive(r *libaudit.AuditClient) error {
for {
rawEvent, err := client.Receive(false)
rawEvent, err := r.Receive(false)
if err != nil {
return errors.Wrap(err, "receive failed")
}
Expand Down
6 changes: 3 additions & 3 deletions netlink.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ type NetlinkClient struct {
// (this is useful for debugging).
//
// The returned NetlinkClient must be closed with Close() when finished.
func NewNetlinkClient(proto int, readBuf []byte, resp io.Writer) (*NetlinkClient, error) {
func NewNetlinkClient(proto int, groups uint32, readBuf []byte, resp io.Writer) (*NetlinkClient, error) {
s, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, proto)
if err != nil {
return nil, err
}

src := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
src := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: groups}
if err = syscall.Bind(s, src); err != nil {
syscall.Close(s)
return nil, err
return nil, errors.Wrap(err, "bind failed")
}

pid, err := getPortID(s)
Expand Down
4 changes: 2 additions & 2 deletions netlink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
var _ NetlinkSendReceiver = &NetlinkClient{}

func TestNewNetlinkClient(t *testing.T) {
c, err := NewNetlinkClient(syscall.NETLINK_AUDIT, nil, nil)
c, err := NewNetlinkClient(syscall.NETLINK_AUDIT, 0, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -39,7 +39,7 @@ func TestNewNetlinkClient(t *testing.T) {
// First PID assigned by the kernel will be our actual PID.
assert.EqualValues(t, os.Getpid(), c.pid)

c2, err := NewNetlinkClient(syscall.NETLINK_AUDIT, nil, nil)
c2, err := NewNetlinkClient(syscall.NETLINK_AUDIT, 0, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down