diff --git a/main.go b/main.go index e5a5b18..275e21d 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,6 @@ import ( "os" "runtime" "runtime/trace" - "strconv" "github.com/m-lab/tcp-info/eventsocket" @@ -57,6 +56,7 @@ var ( enableTrace bool outputDir string excludeSrcPorts = flagx.StringArray{} + excludeDstIPs = flagx.StringArray{} ) func init() { @@ -67,6 +67,7 @@ func init() { flag.BoolVar(&enableTrace, "trace", false, "Enable trace") flag.StringVar(&outputDir, "output", "", "Directory in which to put the resulting tree of data. Default is the current directory.") flag.Var(&excludeSrcPorts, "exclude-srcport", "Exclude snapshots with these local ports from saved archives.") + flag.Var(&excludeDstIPs, "exclude-dstip", "Exclude snapshots with these remote IPs from saved archives.") } // NOTES: @@ -115,17 +116,23 @@ func main() { Local: true, } + if len(excludeDstIPs) != 0 { + for _, dip := range excludeDstIPs { + err := ex.AddDstIP(dip) + if err != nil { + log.Printf("skipping; cannot convert ip %q; %v", dip, err) + continue + } + } + } if len(excludeSrcPorts) != 0 { - srcPorts := map[uint16]bool{} for _, port := range excludeSrcPorts { - i, err := strconv.ParseInt(port, 10, 16) + err := ex.AddSrcPort(port) if err != nil { - log.Printf("skipping; cannot convert %q to integer", port) + log.Printf("skipping; cannot convert port %q; %v", port, err) continue } - srcPorts[uint16(i)] = true } - ex.SrcPorts = srcPorts } // Make the saver and construct the message channel, buffering up to 2 batches diff --git a/main_test.go b/main_test.go index d780053..0e45b87 100644 --- a/main_test.go +++ b/main_test.go @@ -32,3 +32,55 @@ func TestMain(t *testing.T) { // REPS=1 should cause main to run once and then exit. main() } + +func TestMainWithExcludeOptions(t *testing.T) { + // Write files to a temp directory. + dir, err := ioutil.TempDir("", "TestMain") + rtx.Must(err, "Could not create tempdir") + defer os.RemoveAll(dir) + + // Make sure that starting up main() does not cause any panics. There's not + // a lot else we can test, but we can at least make sure that it doesn't + // immediately crash. + for _, v := range []struct{ name, val string }{ + {"REPS", "1"}, + {"TRACE", "true"}, + {"OUTPUT", dir}, + {"TCPINFO_EVENTSOCKET", dir + "/eventsock.sock"}, + {"PROMETHEUSX_LISTEN_ADDRESS", ":0"}, + {"EXCLUDE_SRCPORT", "443"}, + {"EXCLUDE_DSTIP", "172.25.0.1"}, + } { + cleanup := osx.MustSetenv(v.name, v.val) + defer cleanup() + } + + // REPS=1 should cause main to run once and then exit. + main() +} + +func TestMainWithBadExcludeOptions(t *testing.T) { + // Write files to a temp directory. + dir, err := ioutil.TempDir("", "TestMain") + rtx.Must(err, "Could not create tempdir") + defer os.RemoveAll(dir) + + // Make sure that starting up main() does not cause any panics. There's not + // a lot else we can test, but we can at least make sure that it doesn't + // immediately crash. + for _, v := range []struct{ name, val string }{ + {"REPS", "1"}, + {"TRACE", "true"}, + {"OUTPUT", dir}, + {"TCPINFO_EVENTSOCKET", dir + "/eventsock.sock"}, + {"PROMETHEUSX_LISTEN_ADDRESS", ":0"}, + {"EXCLUDE_SRCPORT", "NOT_AN_INT"}, + {"EXCLUDE_DSTIP", ";not-an-ip;"}, + } { + cleanup := osx.MustSetenv(v.name, v.val) + defer cleanup() + } + + // REPS=1 should cause main to run once and then exit. + main() +} diff --git a/netlink/archival-record.go b/netlink/archival-record.go index 482424d..07664ea 100644 --- a/netlink/archival-record.go +++ b/netlink/archival-record.go @@ -6,10 +6,12 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "flag" "io" "log" "net" + "strconv" "time" "unsafe" @@ -53,6 +55,42 @@ type ExcludeConfig struct { Local bool // SrcPorts excludes connections from specific source ports. SrcPorts map[uint16]bool + DstIPs map[[16]byte]bool +} + +// AddSrcPort adds the given port to the set of source ports to exclude. +func (ex *ExcludeConfig) AddSrcPort(port string) error { + i, err := strconv.ParseInt(port, 10, 16) + if err != nil { + return err + } + if ex.SrcPorts == nil { + ex.SrcPorts = map[uint16]bool{} + } + ex.SrcPorts[uint16(i)] = true + return nil +} + +// AddDstIP adds the given dst IP address to the set of destination IPs to exclude. +func (ex *ExcludeConfig) AddDstIP(dst string) error { + ip := net.ParseIP(dst) + if ip == nil { + return errors.New("invalid ip: " + dst) + } + if ex.DstIPs == nil { + ex.DstIPs = map[[16]byte]bool{} + } + key := [16]byte{} + if ip.To4() != nil { + // NOTE: The Linux-native byte position for IPv4 addresses is the first four bytes. + // The net.IP package format uses the last four bytes. Copy the net.IP bytes to a + // new array to generate a key for dstIPs. + copy(key[:], ip[12:]) + } else { + copy(key[:], ip[:]) + } + ex.DstIPs[key] = true + return nil } // ParseRouteAttr parses a byte array into slice of NetlinkRouteAttr struct. @@ -94,6 +132,12 @@ func MakeArchivalRecord(msg *NetlinkMessage, exclude *ExcludeConfig) (*ArchivalR if exclude.Local && (isLocal(idm.ID.SrcIP()) || isLocal(idm.ID.DstIP())) { return nil, nil } + if exclude.DstIPs != nil && exclude.DstIPs[idm.ID.IDiagDst] { + // Note: byte-key lookup is preferable for performance than + // net.IP-to-String formatting. And, a byte array can be a map key, + // while a net.IP byte slice cannot. + return nil, nil + } } record := ArchivalRecord{RawIDM: raw} diff --git a/netlink/archival-record_test.go b/netlink/archival-record_test.go new file mode 100644 index 0000000..67e7a93 --- /dev/null +++ b/netlink/archival-record_test.go @@ -0,0 +1,147 @@ +// Package netlink contains the bare minimum needed to partially parse netlink messages. +package netlink + +import ( + "reflect" + "testing" + "unsafe" + + "github.com/m-lab/tcp-info/inetdiag" +) + +func inet2bytes(inet *inetdiag.InetDiagMsg) []byte { + const sz = int(unsafe.Sizeof(inetdiag.InetDiagMsg{})) + return (*[sz]byte)(unsafe.Pointer(inet))[:] +} + +func TestMakeArchivalRecord(t *testing.T) { + id := inetdiag.LinuxSockID{ + IDiagSPort: [2]byte{0, 77}, // src port + IDiagSrc: [16]byte{127, 0, 0, 1}, // localhost + IDiagDst: [16]byte{172, 25, 0, 1}, // dst ip + } + tests := []struct { + name string + msg *NetlinkMessage + exclude *ExcludeConfig + want *ArchivalRecord + wantErr bool + }{ + { + name: "exclude-local", + msg: &NetlinkMessage{ + Header: NlMsghdr{Type: 20}, + Data: inet2bytes(&inetdiag.InetDiagMsg{ID: id}), + }, + exclude: &ExcludeConfig{ + Local: true, + }, + }, + { + name: "exclude-srcport", + msg: &NetlinkMessage{ + Header: NlMsghdr{Type: 20}, + Data: inet2bytes(&inetdiag.InetDiagMsg{ID: id}), + }, + exclude: &ExcludeConfig{ + SrcPorts: map[uint16]bool{77: true}, + }, + }, + { + name: "exclude-dstip", + msg: &NetlinkMessage{ + Header: NlMsghdr{Type: 20}, + Data: inet2bytes(&inetdiag.InetDiagMsg{ID: id}), + }, + exclude: &ExcludeConfig{ + DstIPs: map[[16]byte]bool{[16]byte{172, 25, 0, 1}: true}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // All cases should return nil. + got, err := MakeArchivalRecord(tt.msg, tt.exclude) + if err != nil { + t.Errorf("MakeArchivalRecord() error = %v, wantErr nil", err) + return + } + if got != nil { + t.Errorf("MakeArchivalRecord() = %v, want nil", got) + } + }) + } +} + +func TestExcludeConfig_AddSrcPort(t *testing.T) { + tests := []struct { + name string + port string + wantPorts map[uint16]bool + wantErr bool + }{ + { + name: "success", + port: "9999", + wantPorts: map[uint16]bool{ + 9999: true, + }, + }, + { + name: "error", + port: "not-a-port", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ex := &ExcludeConfig{} + if err := ex.AddSrcPort(tt.port); (err != nil) != tt.wantErr { + t.Errorf("ExcludeConfig.AddSrcPort() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(ex.SrcPorts, tt.wantPorts) { + t.Errorf("ExcludeConfig.SrcPorts = %#v, want %#v", ex.SrcPorts, tt.wantPorts) + } + }) + } +} + +func TestExcludeConfig_AddDstIP(t *testing.T) { + tests := []struct { + name string + dst string + wantIPs map[[16]byte]bool + wantErr bool + }{ + { + name: "success-ipv4", + dst: "172.25.0.1", + wantIPs: map[[16]byte]bool{ + [16]byte{172, 25, 0, 1}: true, + }, + }, + { + name: "success-ipv6", + dst: "fd0a:008d:ba3f:a834::", + wantIPs: map[[16]byte]bool{ + [16]byte{0xfd, 0x0a, 0x00, 0x8d, 0xba, 0x3f, 0xa8, 0x34}: true, + }, + }, + { + name: "error", + dst: ";not-an-ip;", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ex := &ExcludeConfig{} + if err := ex.AddDstIP(tt.dst); (err != nil) != tt.wantErr { + t.Errorf("ExcludeConfig.AddDstIP() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(ex.DstIPs, tt.wantIPs) { + t.Errorf("ExcludeConfig.DstIPs = %#v, want %#v", ex.DstIPs, tt.wantIPs) + } + }) + } +}