diff --git a/client.go b/client.go index 524d311..050d1f4 100644 --- a/client.go +++ b/client.go @@ -35,7 +35,7 @@ func (c *ConntrackListReq) toWireFormat() []byte { return b } -func connectNetfilter(groups uint32) (int, *syscall.SockaddrNetlink, error) { +func connectNetfilter(bufferSize int, groups uint32) (int, *syscall.SockaddrNetlink, error) { s, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_NETFILTER) if err != nil { return 0, nil, err @@ -47,13 +47,22 @@ func connectNetfilter(groups uint32) (int, *syscall.SockaddrNetlink, error) { if err := syscall.Bind(s, lsa); err != nil { return 0, nil, err } + if bufferSize > 0 { + // Speculatively try SO_RCVBUFFORCE which needs CAP_NET_ADMIN + if err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufferSize); err != nil { + // and if that doesn't work fall back to the ordinary SO_RCVBUF + if err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufferSize); err != nil { + return 0, nil, err + } + } + } return s, lsa, nil } // Make syscall asking for all connections. Invoke 'cb' for each connection. -func queryAllConnections(cb func(Conn)) error { - s, lsa, err := connectNetfilter(0) +func queryAllConnections(bufferSize int, cb func(Conn)) error { + s, lsa, err := connectNetfilter(bufferSize, 0) if err != nil { return err } @@ -86,7 +95,7 @@ func queryAllConnections(cb func(Conn)) error { func StreamAllConnections() chan Conn { ch := make(chan Conn, 1) go func() { - queryAllConnections(func(c Conn) { + queryAllConnections(0, func(c Conn) { ch <- c }) close(ch) @@ -96,8 +105,13 @@ func StreamAllConnections() chan Conn { // Lists all the connections that conntrack is tracking. func Connections() ([]Conn, error) { + return ConnectionsSize(0) +} + +// Lists all the connections that conntrack is tracking, using specified netlink buffer size. +func ConnectionsSize(bufferSize int) ([]Conn, error) { var conns []Conn - queryAllConnections(func(c Conn) { + queryAllConnections(bufferSize, func(c Conn) { conns = append(conns, c) }) return conns, nil @@ -107,7 +121,7 @@ func Connections() ([]Conn, error) { func Established() ([]ConnTCP, error) { var conns []ConnTCP local := localIPs() - err := queryAllConnections(func(c Conn) { + err := queryAllConnections(0, func(c Conn) { if c.MsgType != NfctMsgUpdate { fmt.Printf("msg isn't an update: %d\n", c.MsgType) return @@ -128,7 +142,12 @@ func Established() ([]ConnTCP, error) { // Follow gives a channel with all changes. func Follow(flags uint32) (<-chan Conn, func(), error) { - s, _, err := connectNetfilter(flags) + return FollowSize(0, flags) +} + +// Follow gives a channel with all changes, , using specified netlink buffer size. +func FollowSize(bufferSize int, flags uint32) (<-chan Conn, func(), error) { + s, _, err := connectNetfilter(bufferSize, flags) stop := func() { syscall.Close(s) }