Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude DstIPs #137

Merged
merged 8 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"os"
"runtime"
"runtime/trace"
"strconv"

"github.com/m-lab/tcp-info/eventsocket"

Expand Down Expand Up @@ -57,6 +56,7 @@ var (
enableTrace bool
outputDir string
excludeSrcPorts = flagx.StringArray{}
excludeDstIPs = flagx.StringArray{}
)

func init() {
Expand All @@ -67,6 +67,7 @@ func init() {
flag.BoolVar(&enableTrace, "trace", false, "Enable trace")
flag.StringVar(&outputDir, "output", "", "Directory in which to put the resulting tree of data. Default is the current directory.")
flag.Var(&excludeSrcPorts, "exclude-srcport", "Exclude snapshots with these local ports from saved archives.")
flag.Var(&excludeDstIPs, "exclude-dstip", "Exclude snapshots with these remote IPs from saved archives.")
}

// NOTES:
Expand Down Expand Up @@ -115,17 +116,23 @@ func main() {
Local: true,
}

if len(excludeDstIPs) != 0 {
for _, dip := range excludeDstIPs {
err := ex.AddDstIP(dip)
if err != nil {
log.Printf("skipping; cannot convert ip %q; %v", dip, err)
continue
}
}
}
if len(excludeSrcPorts) != 0 {
srcPorts := map[uint16]bool{}
for _, port := range excludeSrcPorts {
i, err := strconv.ParseInt(port, 10, 16)
err := ex.AddSrcPort(port)
if err != nil {
log.Printf("skipping; cannot convert %q to integer", port)
log.Printf("skipping; cannot convert port %q; %v", port, err)
continue
}
srcPorts[uint16(i)] = true
}
ex.SrcPorts = srcPorts
}

// Make the saver and construct the message channel, buffering up to 2 batches
Expand Down
52 changes: 52 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,55 @@ func TestMain(t *testing.T) {
// REPS=1 should cause main to run once and then exit.
main()
}

func TestMainWithExcludeOptions(t *testing.T) {
// Write files to a temp directory.
dir, err := ioutil.TempDir("", "TestMain")
rtx.Must(err, "Could not create tempdir")
defer os.RemoveAll(dir)

// Make sure that starting up main() does not cause any panics. There's not
// a lot else we can test, but we can at least make sure that it doesn't
// immediately crash.
for _, v := range []struct{ name, val string }{
{"REPS", "1"},
{"TRACE", "true"},
{"OUTPUT", dir},
{"TCPINFO_EVENTSOCKET", dir + "/eventsock.sock"},
{"PROMETHEUSX_LISTEN_ADDRESS", ":0"},
{"EXCLUDE_SRCPORT", "443"},
{"EXCLUDE_DSTIP", "172.25.0.1"},
} {
cleanup := osx.MustSetenv(v.name, v.val)
defer cleanup()
}

// REPS=1 should cause main to run once and then exit.
main()
}

func TestMainWithBadExcludeOptions(t *testing.T) {
// Write files to a temp directory.
dir, err := ioutil.TempDir("", "TestMain")
rtx.Must(err, "Could not create tempdir")
defer os.RemoveAll(dir)

// Make sure that starting up main() does not cause any panics. There's not
// a lot else we can test, but we can at least make sure that it doesn't
// immediately crash.
for _, v := range []struct{ name, val string }{
{"REPS", "1"},
{"TRACE", "true"},
{"OUTPUT", dir},
{"TCPINFO_EVENTSOCKET", dir + "/eventsock.sock"},
{"PROMETHEUSX_LISTEN_ADDRESS", ":0"},
{"EXCLUDE_SRCPORT", "NOT_AN_INT"},
{"EXCLUDE_DSTIP", ";not-an-ip;"},
} {
cleanup := osx.MustSetenv(v.name, v.val)
defer cleanup()
}

// REPS=1 should cause main to run once and then exit.
main()
}
44 changes: 44 additions & 0 deletions netlink/archival-record.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"flag"
"io"
"log"
"net"
"strconv"
"time"
"unsafe"

Expand Down Expand Up @@ -53,6 +55,42 @@ type ExcludeConfig struct {
Local bool
// SrcPorts excludes connections from specific source ports.
SrcPorts map[uint16]bool
DstIPs map[[16]byte]bool
}

// AddSrcPort adds the given port to the set of source ports to exclude.
func (ex *ExcludeConfig) AddSrcPort(port string) error {
i, err := strconv.ParseInt(port, 10, 16)
if err != nil {
return err
}
if ex.SrcPorts == nil {
ex.SrcPorts = map[uint16]bool{}
}
ex.SrcPorts[uint16(i)] = true
return nil
}

// AddDstIP adds the given dst IP address to the set of destination IPs to exclude.
func (ex *ExcludeConfig) AddDstIP(dst string) error {
ip := net.ParseIP(dst)
if ip == nil {
return errors.New("invalid ip: " + dst)
}
if ex.DstIPs == nil {
ex.DstIPs = map[[16]byte]bool{}
}
key := [16]byte{}
if ip.To4() != nil {
// NOTE: The Linux-native byte position for IPv4 addresses is the first four bytes.
// The net.IP package format uses the last four bytes. Copy the net.IP bytes to a
// new array to generate a key for dstIPs.
copy(key[:], ip[12:])
} else {
copy(key[:], ip[:])
}
ex.DstIPs[key] = true
return nil
}

// ParseRouteAttr parses a byte array into slice of NetlinkRouteAttr struct.
Expand Down Expand Up @@ -94,6 +132,12 @@ func MakeArchivalRecord(msg *NetlinkMessage, exclude *ExcludeConfig) (*ArchivalR
if exclude.Local && (isLocal(idm.ID.SrcIP()) || isLocal(idm.ID.DstIP())) {
return nil, nil
}
if exclude.DstIPs != nil && exclude.DstIPs[idm.ID.IDiagDst] {
// Note: byte-key lookup is preferable for performance than
// net.IP-to-String formatting. And, a byte array can be a map key,
// while a net.IP byte slice cannot.
return nil, nil
}
}

record := ArchivalRecord{RawIDM: raw}
Expand Down
147 changes: 147 additions & 0 deletions netlink/archival-record_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Package netlink contains the bare minimum needed to partially parse netlink messages.
package netlink

import (
"reflect"
"testing"
"unsafe"

"github.com/m-lab/tcp-info/inetdiag"
)

func inet2bytes(inet *inetdiag.InetDiagMsg) []byte {
const sz = int(unsafe.Sizeof(inetdiag.InetDiagMsg{}))
return (*[sz]byte)(unsafe.Pointer(inet))[:]
}

func TestMakeArchivalRecord(t *testing.T) {
id := inetdiag.LinuxSockID{
IDiagSPort: [2]byte{0, 77}, // src port
IDiagSrc: [16]byte{127, 0, 0, 1}, // localhost
IDiagDst: [16]byte{172, 25, 0, 1}, // dst ip
}
tests := []struct {
name string
msg *NetlinkMessage
exclude *ExcludeConfig
want *ArchivalRecord
wantErr bool
}{
{
name: "exclude-local",
msg: &NetlinkMessage{
Header: NlMsghdr{Type: 20},
Data: inet2bytes(&inetdiag.InetDiagMsg{ID: id}),
},
exclude: &ExcludeConfig{
Local: true,
},
},
{
name: "exclude-srcport",
msg: &NetlinkMessage{
Header: NlMsghdr{Type: 20},
Data: inet2bytes(&inetdiag.InetDiagMsg{ID: id}),
},
exclude: &ExcludeConfig{
SrcPorts: map[uint16]bool{77: true},
},
},
{
name: "exclude-dstip",
msg: &NetlinkMessage{
Header: NlMsghdr{Type: 20},
Data: inet2bytes(&inetdiag.InetDiagMsg{ID: id}),
},
exclude: &ExcludeConfig{
DstIPs: map[[16]byte]bool{[16]byte{172, 25, 0, 1}: true},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// All cases should return nil.
got, err := MakeArchivalRecord(tt.msg, tt.exclude)
if err != nil {
t.Errorf("MakeArchivalRecord() error = %v, wantErr nil", err)
return
}
if got != nil {
t.Errorf("MakeArchivalRecord() = %v, want nil", got)
}
})
}
}

func TestExcludeConfig_AddSrcPort(t *testing.T) {
tests := []struct {
name string
port string
wantPorts map[uint16]bool
wantErr bool
}{
{
name: "success",
port: "9999",
wantPorts: map[uint16]bool{
9999: true,
},
},
{
name: "error",
port: "not-a-port",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ex := &ExcludeConfig{}
if err := ex.AddSrcPort(tt.port); (err != nil) != tt.wantErr {
t.Errorf("ExcludeConfig.AddSrcPort() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(ex.SrcPorts, tt.wantPorts) {
t.Errorf("ExcludeConfig.SrcPorts = %#v, want %#v", ex.SrcPorts, tt.wantPorts)
}
})
}
}

func TestExcludeConfig_AddDstIP(t *testing.T) {
tests := []struct {
name string
dst string
wantIPs map[[16]byte]bool
wantErr bool
}{
{
name: "success-ipv4",
dst: "172.25.0.1",
wantIPs: map[[16]byte]bool{
[16]byte{172, 25, 0, 1}: true,
},
},
{
name: "success-ipv6",
dst: "fd0a:008d:ba3f:a834::",
wantIPs: map[[16]byte]bool{
[16]byte{0xfd, 0x0a, 0x00, 0x8d, 0xba, 0x3f, 0xa8, 0x34}: true,
},
},
{
name: "error",
dst: ";not-an-ip;",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ex := &ExcludeConfig{}
if err := ex.AddDstIP(tt.dst); (err != nil) != tt.wantErr {
t.Errorf("ExcludeConfig.AddDstIP() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(ex.DstIPs, tt.wantIPs) {
t.Errorf("ExcludeConfig.DstIPs = %#v, want %#v", ex.DstIPs, tt.wantIPs)
}
})
}
}