diff --git a/Gopkg.lock b/Gopkg.lock index a4faca88..fdc3f16c 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -76,7 +76,7 @@ [[projects]] name = "github.com/vishvananda/netlink" packages = [".","nl"] - revision = "177f1ceba557262b3f1c3aba4df93a29199fb4eb" + revision = "b7fbf1f5291ecf8ae5179d3202e914cb98cfe400" [[projects]] name = "github.com/vishvananda/netns" @@ -96,6 +96,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "52bc4229f63ca78c2086df724034421cf56f51e281edb843f319fb0644af36d3" + inputs-digest = "82781172d7b56c5605cb416f72f21b8cd71ae5f49ef87cfe940e8f7b3d0f3c21" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 8cde9af0..ecc09112 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -114,7 +114,7 @@ [[constraint]] name = "github.com/vishvananda/netlink" - revision = "177f1ceba557262b3f1c3aba4df93a29199fb4eb" + revision = "b7fbf1f5291ecf8ae5179d3202e914cb98cfe400" [[constraint]] name = "github.com/vishvananda/netns" diff --git a/hyperstart.go b/hyperstart.go index c855171a..efd8ccc5 100644 --- a/hyperstart.go +++ b/hyperstart.go @@ -172,7 +172,7 @@ func (h *hyper) buildNetworkInterfacesAndRoutes(pod Pod) ([]hyperstart.NetworkIf return []hyperstart.NetworkIface{}, []hyperstart.Route{}, nil } - netIfaces, err := getIfacesFromNetNs(networkNS.NetNsPath) + netIfaces, err := getIfacesFromNetNsAll(networkNS.NetNsPath) if err != nil { return []hyperstart.NetworkIface{}, []hyperstart.Route{}, err } diff --git a/network.go b/network.go index 4051d5e6..4e588758 100644 --- a/network.go +++ b/network.go @@ -19,9 +19,11 @@ package virtcontainers import ( "errors" "fmt" + "math/rand" "net" "os" "runtime" + "time" types "github.com/containernetworking/cni/pkg/types/current" "github.com/containernetworking/plugins/pkg/ns" @@ -30,10 +32,40 @@ import ( "golang.org/x/sys/unix" ) -// Introduces constants related to network routes. +// NetInterworkingModel defines the network model connecting +// the network interface to the virtual machine. +type NetInterworkingModel int + +const ( + // ModelBridged uses a linux bridge to interconnect + // the container interface to the VM. This is the + // safe default that works for most cases except + // macvlan and ipvlan + ModelBridged NetInterworkingModel = iota + + // ModelMacVtap can be used when the Container network + // interface can be bridged using macvtap + ModelMacVtap + + // ModelEnlightened can be used when the Network plugins + // are enlightened to create VM native interfaces + // when requested by the runtime + // This will be used for vethtap, macvtap, ipvtap + ModelEnlightened +) + +// DefaultNetInterworkingModel is a package level default +// that determines how the VM should be connected to the +// the container network interface +var DefaultNetInterworkingModel = ModelMacVtap + +// Introduces constants related to networking const ( defaultRouteDest = "0.0.0.0/0" defaultRouteLabel = "default" + defaultFilePerms = 0600 + defaultQlen = 1500 + defaultQueues = 8 ) type netIfaceAddrs struct { @@ -45,14 +77,17 @@ type netIfaceAddrs struct { type NetworkInterface struct { Name string HardAddr string + Addrs []netlink.Addr } -// NetworkInterfacePair defines a pair between TAP and virtual network interfaces. +// NetworkInterfacePair defines a pair between VM and virtual network interfaces. type NetworkInterfacePair struct { ID string Name string VirtIface NetworkInterface TAPIface NetworkInterface + NetInterworkingModel + VMFds []*os.File } // NetworkConfig is the network configuration related to a network. @@ -159,7 +194,7 @@ func runNetworkCommon(networkNSPath string, cb func() error) error { func addNetworkCommon(pod Pod, networkNS *NetworkNamespace) error { err := doNetNS(networkNS.NetNsPath, func(_ ns.NetNS) error { for idx := range networkNS.Endpoints { - if err := bridgeNetworkPair(&(networkNS.Endpoints[idx].NetPair)); err != nil { + if err := xconnectVMNetwork(&(networkNS.Endpoints[idx].NetPair), true); err != nil { return err } } @@ -176,7 +211,7 @@ func addNetworkCommon(pod Pod, networkNS *NetworkNamespace) error { func removeNetworkCommon(networkNS NetworkNamespace) error { return doNetNS(networkNS.NetNsPath, func(_ ns.NetNS) error { for _, endpoint := range networkNS.Endpoints { - err := unBridgeNetworkPair(endpoint.NetPair) + err := xconnectVMNetwork(&(endpoint.NetPair), false) if err != nil { return err } @@ -200,6 +235,22 @@ func createLink(netHandle *netlink.Handle, name string, expectedLink netlink.Lin LinkAttrs: netlink.LinkAttrs{Name: name}, Mode: netlink.TUNTAP_MODE_TAP, } + case (&netlink.Macvtap{}).Type(): + qlen := expectedLink.Attrs().TxQLen + if qlen <= 0 { + qlen = defaultQlen + } + newLink = &netlink.Macvtap{ + Macvlan: netlink.Macvlan{ + Mode: netlink.MACVLAN_MODE_BRIDGE, + LinkAttrs: netlink.LinkAttrs{ + Index: expectedLink.Attrs().Index, + Name: name, + TxQLen: qlen, + ParentIndex: expectedLink.Attrs().ParentIndex, + }, + }, + } default: return nil, fmt.Errorf("Unsupported link type %s", expectedLink.Type()) } @@ -230,6 +281,10 @@ func getLinkByName(netHandle *netlink.Handle, name string, expectedLink netlink. if l, ok := link.(*netlink.Veth); ok { return l, nil } + case (&netlink.Macvtap{}).Type(): + if l, ok := link.(*netlink.Macvtap); ok { + return l, nil + } default: return nil, fmt.Errorf("Unsupported link type %s", expectedLink.Type()) } @@ -237,6 +292,193 @@ func getLinkByName(netHandle *netlink.Handle, name string, expectedLink netlink. return nil, fmt.Errorf("Incorrect link type %s, expecting %s", link.Type(), expectedLink.Type()) } +func xconnectVMNetwork(netPair *NetworkInterfacePair, connect bool) error { + switch DefaultNetInterworkingModel { + case ModelBridged: + netPair.NetInterworkingModel = ModelBridged + if connect { + return bridgeNetworkPair(netPair) + } + return unBridgeNetworkPair(*netPair) + case ModelMacVtap: + netPair.NetInterworkingModel = ModelMacVtap + if connect { + return tapNetworkPair(netPair) + } + return untapNetworkPair(*netPair) + case ModelEnlightened: + return fmt.Errorf("Unsupported networking model") + default: + return fmt.Errorf("Invalid networking model") + } +} + +func createMacvtapFds(linkIndex int, queues int) ([]*os.File, error) { + fds := make([]*os.File, queues) + + //mq support + for q := 0; q < queues; q++ { + + tapDev := fmt.Sprintf("/dev/tap%d", linkIndex) + + f, err := os.OpenFile(tapDev, os.O_RDWR, defaultFilePerms) + if err != nil { + cleanupFds(fds, q) + return nil, err + } + fds[q] = f + } + + return fds, nil +} + +// There is a limitation in the linux kernel that prevents a macvtap/macvlan link +// from getting the correct link index when created in a network namespace +// https://github.com/clearcontainers/runtime/issues/708 +// +// Till that bug is fixed we need to pick a random non conflicting index and try to +// create a link. If that fails, we need to try with another. +// All the kernel does not check if the link id conflicts with a link id on the host +// hence we need to offset the link id to prevent any overlaps with the host index +// +// Here the kernel will ensure that there is no race condition + +const hostLinkOffset = 8192 // Host should not have more than 8k interfaces +const linkRange = 0xFFFF // This will allow upto 2^16 containers +const linkRetries = 128 // The numbers of time we try to find a non conflicting index +const macvtapWorkaround = true + +func createMacVtap(netHandle *netlink.Handle, name string, link netlink.Link) (taplink netlink.Link, err error) { + + if !macvtapWorkaround { + taplink, err = createLink(netHandle, name, link) + return + } + + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for i := 0; i < linkRetries; i++ { + index := hostLinkOffset + (r.Int() & linkRange) + link.Attrs().Index = index + taplink, err = createLink(netHandle, name, link) + if err == nil { + break + } + } + + return +} + +func clearIPs(link netlink.Link, addrs []netlink.Addr) error { + for _, addr := range addrs { + if err := netlink.AddrDel(link, &addr); err != nil { + return err + } + } + return nil +} + +func setIPs(link netlink.Link, addrs []netlink.Addr) error { + for _, addr := range addrs { + if err := netlink.AddrAdd(link, &addr); err != nil { + return err + } + } + return nil +} + +func tapNetworkPair(netPair *NetworkInterfacePair) error { + netHandle, err := netlink.NewHandle() + if err != nil { + return err + } + defer netHandle.Delete() + + vethLink, err := getLinkByName(netHandle, netPair.VirtIface.Name, &netlink.Veth{}) + if err != nil { + return fmt.Errorf("Could not get veth interface: %s: %s", netPair.VirtIface.Name, err) + } + vethLinkAttrs := vethLink.Attrs() + + // Attach the macvtap interface to the underlying container + // interface. Also picks relevant attributes from the parent + tapLink, err := createMacVtap(netHandle, netPair.TAPIface.Name, + &netlink.Macvtap{ + Macvlan: netlink.Macvlan{ + LinkAttrs: netlink.LinkAttrs{ + TxQLen: vethLinkAttrs.TxQLen, + ParentIndex: vethLinkAttrs.Index, + }, + }, + }) + + if err != nil { + return fmt.Errorf("Could not create TAP interface: %s", err) + } + + // Save the veth MAC address to the TAP so that it can later be used + // to build the hypervisor command line. This MAC address has to be + // the one inside the VM in order to avoid any firewall issues. The + // bridge created by the network plugin on the host actually expects + // to see traffic from this MAC address and not another one. + tapHardAddr := vethLinkAttrs.HardwareAddr + netPair.TAPIface.HardAddr = vethLinkAttrs.HardwareAddr.String() + + if err := netHandle.LinkSetMTU(tapLink, vethLinkAttrs.MTU); err != nil { + return fmt.Errorf("Could not set TAP MTU %d: %s", vethLinkAttrs.MTU, err) + } + + hardAddr, err := net.ParseMAC(netPair.VirtIface.HardAddr) + if err != nil { + return err + } + if err := netHandle.LinkSetHardwareAddr(vethLink, hardAddr); err != nil { + return fmt.Errorf("Could not set MAC address %s for veth interface %s: %s", + netPair.VirtIface.HardAddr, netPair.VirtIface.Name, err) + } + + if err := netHandle.LinkSetHardwareAddr(tapLink, tapHardAddr); err != nil { + return fmt.Errorf("Could not set MAC address %s for veth interface %s: %s", + netPair.VirtIface.HardAddr, netPair.VirtIface.Name, err) + } + + if err := netHandle.LinkSetUp(tapLink); err != nil { + return fmt.Errorf("Could not enable TAP %s: %s", netPair.TAPIface.Name, err) + } + + // Clear the IP addresses from the veth interface to prevent ARP conflict + netPair.VirtIface.Addrs, err = netlink.AddrList(vethLink, netlink.FAMILY_V4) + if err != nil { + return fmt.Errorf("Unable to obtain veth IP addresses: %s", err) + } + + if err := clearIPs(vethLink, netPair.VirtIface.Addrs); err != nil { + return fmt.Errorf("Unable to clear veth IP addresses: %s", err) + } + + if err := netHandle.LinkSetUp(vethLink); err != nil { + return fmt.Errorf("Could not enable veth %s: %s", netPair.VirtIface.Name, err) + } + + // Note: The underlying interfaces need to be up prior to fd creation. + + // Setup the multiqueue fds to be consumed by QEMU as macvtap cannot + // be directly connected. + // Ideally we want + // netdev.FDs, err = createMacvtapFds(netdev.ID, int(config.SMP.CPUs)) + + // We do not have global context here, hence a manifest constant + // that matches our minimum vCPU configuration + // Another option is to defer this to ciao qemu library which does have + // global context but cannot handle errors when setting up the network + netPair.VMFds, err = createMacvtapFds(tapLink.Attrs().Index, defaultQueues) + if err != nil { + return fmt.Errorf("Could not setup macvtap fds %s: %s", netPair.TAPIface, err) + } + + return nil +} + func bridgeNetworkPair(netPair *NetworkInterfacePair) error { netHandle, err := netlink.NewHandle() if err != nil { @@ -251,7 +493,7 @@ func bridgeNetworkPair(netPair *NetworkInterfacePair) error { vethLink, err := getLinkByName(netHandle, netPair.VirtIface.Name, &netlink.Veth{}) if err != nil { - return fmt.Errorf("Could not get veth interface: %s", err) + return fmt.Errorf("Could not get veth interface %s : %s", netPair.VirtIface.Name, err) } vethLinkAttrs := vethLink.Attrs() @@ -307,6 +549,37 @@ func bridgeNetworkPair(netPair *NetworkInterfacePair) error { return nil } +func untapNetworkPair(netPair NetworkInterfacePair) error { + netHandle, err := netlink.NewHandle() + if err != nil { + return err + } + defer netHandle.Delete() + + tapLink, err := getLinkByName(netHandle, netPair.TAPIface.Name, &netlink.Macvtap{}) + if err != nil { + return fmt.Errorf("Could not get TAP interface %s: %s", netPair.TAPIface.Name, err) + } + + if err := netHandle.LinkDel(tapLink); err != nil { + return fmt.Errorf("Could not remove TAP %s: %s", netPair.TAPIface.Name, err) + } + + vethLink, err := getLinkByName(netHandle, netPair.VirtIface.Name, &netlink.Veth{}) + if err != nil { + // The veth pair is not totally managed by virtcontainers + virtLog.Warn("Could not get veth interface %s: %s", netPair.VirtIface.Name, err) + } else { + if err := netHandle.LinkSetDown(vethLink); err != nil { + return fmt.Errorf("Could not disable veth %s: %s", netPair.VirtIface.Name, err) + } + } + + // Restore the IPs that were cleared + err = setIPs(vethLink, netPair.VirtIface.Addrs) + return err +} + func unBridgeNetworkPair(netPair NetworkInterfacePair) error { netHandle, err := netlink.NewHandle() if err != nil { @@ -480,7 +753,7 @@ func createNetworkEndpoints(numOfEndpoints int) (endpoints []Endpoint, err error return endpoints, nil } -func getIfacesFromNetNs(networkNSPath string) ([]netIfaceAddrs, error) { +func getIfacesFromNetNsFilter(networkNSPath string, ipFilter bool) ([]netIfaceAddrs, error) { var netIfaces []netIfaceAddrs if networkNSPath == "" { @@ -499,13 +772,15 @@ func getIfacesFromNetNs(networkNSPath string) ([]netIfaceAddrs, error) { return err } - // Ignore unconfigured network interfaces - // These are either base tunnel devices - // that are not namespaced like - // gre0, gretap0, sit0, ipip0, tunl0 - // or incorrectly setup interfaces - if (addrs == nil) || (len(addrs) == 0) { - continue + if ipFilter { + // Ignore unconfigured network interfaces + // These are either base tunnel devices + // that are not namespaced like + // gre0, gretap0, sit0, ipip0, tunl0 + // or incorrectly setup interfaces + if (addrs == nil) || (len(addrs) == 0) { + continue + } } netIface := netIfaceAddrs{ @@ -525,6 +800,16 @@ func getIfacesFromNetNs(networkNSPath string) ([]netIfaceAddrs, error) { return netIfaces, nil } +func getIfacesFromNetNsAll(networkNSPath string) ([]netIfaceAddrs, error) { + // get all interfaces, even those without IP + return getIfacesFromNetNsFilter(networkNSPath, false) +} + +func getIfacesFromNetNs(networkNSPath string) ([]netIfaceAddrs, error) { + // get only the interfaces with valid IP addrsses + return getIfacesFromNetNsFilter(networkNSPath, true) +} + func getNetIfaceByName(name string, netIfaces []netIfaceAddrs) (net.Interface, error) { for _, netIface := range netIfaces { if netIface.iface.Name == name { diff --git a/qemu.go b/qemu.go index befb48fb..7daa63c6 100644 --- a/qemu.go +++ b/qemu.go @@ -310,11 +310,28 @@ func (q *qemu) appendSocket(devices []ciaoQemu.Device, socket Socket) []ciaoQemu return devices } +func networkModelToQemuType(model NetInterworkingModel) ciaoQemu.NetDeviceType { + switch model { + case ModelBridged: + return ciaoQemu.TAP + case ModelMacVtap: + return ciaoQemu.MACVTAP + //case ModelEnlightened: + // Here the Network plugin will create a VM native interface + // which could be MacVtap, IpVtap, SRIOV, veth-tap, vhost-user + // In these cases we will determine the interface type here + // and pass in the native interface through + default: + //TAP should work for most other cases + return ciaoQemu.TAP + } +} + func (q *qemu) appendNetworks(devices []ciaoQemu.Device, endpoints []Endpoint) []ciaoQemu.Device { for idx, endpoint := range endpoints { devices = append(devices, ciaoQemu.NetDevice{ - Type: ciaoQemu.TAP, + Type: networkModelToQemuType(endpoint.NetPair.NetInterworkingModel), Driver: ciaoQemu.VirtioNetPCI, ID: fmt.Sprintf("network-%d", idx), IFName: endpoint.NetPair.TAPIface.Name, @@ -323,6 +340,7 @@ func (q *qemu) appendNetworks(devices []ciaoQemu.Device, endpoints []Endpoint) [ Script: "no", VHost: true, DisableModern: q.nestedRun, + FDs: endpoint.NetPair.VMFds, }, ) } diff --git a/utils.go b/utils.go index 2855e332..1ea59c6a 100644 --- a/utils.go +++ b/utils.go @@ -19,6 +19,7 @@ package virtcontainers import ( "crypto/rand" "fmt" + "os" "os/exec" ) @@ -64,3 +65,16 @@ func reverseString(s string) string { return string(r) } + +func cleanupFds(fds []*os.File, numFds int) { + + maxFds := len(fds) + + if numFds < maxFds { + maxFds = numFds + } + + for i := 0; i < maxFds; i++ { + _ = fds[i].Close() + } +} diff --git a/vendor/github.com/vishvananda/netlink/handle_linux.go b/vendor/github.com/vishvananda/netlink/handle_linux.go index a04ceae6..d37b087c 100644 --- a/vendor/github.com/vishvananda/netlink/handle_linux.go +++ b/vendor/github.com/vishvananda/netlink/handle_linux.go @@ -45,12 +45,27 @@ func (h *Handle) SetSocketTimeout(to time.Duration) error { } tv := syscall.NsecToTimeval(to.Nanoseconds()) for _, sh := range h.sockets { - fd := sh.Socket.GetFd() - err := syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv) - if err != nil { + if err := sh.Socket.SetSendTimeout(&tv); err != nil { return err } - err = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &tv) + if err := sh.Socket.SetReceiveTimeout(&tv); err != nil { + return err + } + } + return nil +} + +// SetSocketReceiveBufferSize sets the receive buffer size for each +// socket in the netlink handle. The maximum value is capped by +// /proc/sys/net/core/rmem_max. +func (h *Handle) SetSocketReceiveBufferSize(size int, force bool) error { + opt := syscall.SO_RCVBUF + if force { + opt = syscall.SO_RCVBUFFORCE + } + for _, sh := range h.sockets { + fd := sh.Socket.GetFd() + err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, opt, size) if err != nil { return err } @@ -58,6 +73,24 @@ func (h *Handle) SetSocketTimeout(to time.Duration) error { return nil } +// GetSocketReceiveBufferSize gets the receiver buffer size for each +// socket in the netlink handle. The retrieved value should be the +// double to the one set for SetSocketReceiveBufferSize. +func (h *Handle) GetSocketReceiveBufferSize() ([]int, error) { + results := make([]int, len(h.sockets)) + i := 0 + for _, sh := range h.sockets { + fd := sh.Socket.GetFd() + size, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF) + if err != nil { + return nil, err + } + results[i] = size + i++ + } + return results, nil +} + // NewHandle returns a netlink handle on the network namespace // specified by ns. If ns=netns.None(), current network namespace // will be assumed diff --git a/vendor/github.com/vishvananda/netlink/handle_test.go b/vendor/github.com/vishvananda/netlink/handle_test.go index e7a7f86f..5356de49 100644 --- a/vendor/github.com/vishvananda/netlink/handle_test.go +++ b/vendor/github.com/vishvananda/netlink/handle_test.go @@ -132,6 +132,31 @@ func TestHandleTimeout(t *testing.T) { } } +func TestHandleReceiveBuffer(t *testing.T) { + h, err := NewHandle() + if err != nil { + t.Fatal(err) + } + defer h.Delete() + if err := h.SetSocketReceiveBufferSize(65536, false); err != nil { + t.Fatal(err) + } + sizes, err := h.GetSocketReceiveBufferSize() + if err != nil { + t.Fatal(err) + } + if len(sizes) != len(h.sockets) { + t.Fatalf("Unexpected number of socket buffer sizes: %d (expected %d)", + len(sizes), len(h.sockets)) + } + for _, s := range sizes { + if s < 65536 || s > 2*65536 { + t.Fatalf("Unexpected socket receive buffer size: %d (expected around %d)", + s, 65536) + } + } +} + func verifySockTimeVal(t *testing.T, fd int, tv syscall.Timeval) { var ( tr syscall.Timeval diff --git a/vendor/github.com/vishvananda/netlink/link.go b/vendor/github.com/vishvananda/netlink/link.go index 59f7ba52..5aa3a179 100644 --- a/vendor/github.com/vishvananda/netlink/link.go +++ b/vendor/github.com/vishvananda/netlink/link.go @@ -37,6 +37,7 @@ type LinkAttrs struct { EncapType string Protinfo *Protinfo OperState LinkOperState + NetNsID int } // LinkOperState represents the values of the IFLA_OPERSTATE link diff --git a/vendor/github.com/vishvananda/netlink/link_linux.go b/vendor/github.com/vishvananda/netlink/link_linux.go index fd8e4cda..18751453 100644 --- a/vendor/github.com/vishvananda/netlink/link_linux.go +++ b/vendor/github.com/vishvananda/netlink/link_linux.go @@ -851,6 +851,10 @@ func (h *Handle) linkModify(link Link, flags int) error { msg.Change |= syscall.IFF_MULTICAST msg.Flags |= syscall.IFF_MULTICAST } + if base.Index != 0 { + msg.Index = int32(base.Index) + } + req.AddData(msg) if base.ParentIndex != 0 { @@ -1268,6 +1272,8 @@ func LinkDeserialize(hdr *syscall.NlMsghdr, m []byte) (Link, error) { } case syscall.IFLA_OPERSTATE: base.OperState = LinkOperState(uint8(attr.Value[0])) + case nl.IFLA_LINK_NETNSID: + base.NetNsID = int(native.Uint32(attr.Value[0:4])) } } diff --git a/vendor/github.com/vishvananda/netlink/link_test.go b/vendor/github.com/vishvananda/netlink/link_test.go index a78e09da..f1fb0e06 100644 --- a/vendor/github.com/vishvananda/netlink/link_test.go +++ b/vendor/github.com/vishvananda/netlink/link_test.go @@ -38,6 +38,12 @@ func testLinkAddDel(t *testing.T, link Link) { rBase := result.Attrs() + if base.Index != 0 { + if base.Index != rBase.Index { + t.Fatalf("index is %d, should be %d", rBase.Index, base.Index) + } + } + if vlan, ok := link.(*Vlan); ok { other, ok := result.(*Vlan) if !ok { @@ -260,6 +266,13 @@ func compareVxlan(t *testing.T, expected, actual *Vxlan) { } } +func TestLinkAddDelWithIndex(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + + testLinkAddDel(t, &Dummy{LinkAttrs{Index: 1000, Name: "foo"}}) +} + func TestLinkAddDelDummy(t *testing.T) { tearDown := setUpNetlinkTest(t) defer tearDown() diff --git a/vendor/github.com/vishvananda/netlink/nl/nl_linux.go b/vendor/github.com/vishvananda/netlink/nl/nl_linux.go index 1329acd8..72f7f6af 100644 --- a/vendor/github.com/vishvananda/netlink/nl/nl_linux.go +++ b/vendor/github.com/vishvananda/netlink/nl/nl_linux.go @@ -621,6 +621,20 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) { return syscall.ParseNetlinkMessage(rb) } +// SetSendTimeout allows to set a send timeout on the socket +func (s *NetlinkSocket) SetSendTimeout(timeout *syscall.Timeval) error { + // Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine + // remains stuck on a send on a closed fd + return syscall.SetsockoptTimeval(int(s.fd), syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, timeout) +} + +// SetReceiveTimeout allows to set a receive timeout on the socket +func (s *NetlinkSocket) SetReceiveTimeout(timeout *syscall.Timeval) error { + // Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine + // remains stuck on a recvmsg on a closed fd + return syscall.SetsockoptTimeval(int(s.fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, timeout) +} + func (s *NetlinkSocket) GetPid() (uint32, error) { fd := int(atomic.LoadInt32(&s.fd)) lsa, err := syscall.Getsockname(fd) diff --git a/vendor/github.com/vishvananda/netlink/nl/nl_linux_test.go b/vendor/github.com/vishvananda/netlink/nl/nl_linux_test.go index 521a7ef3..b88b560b 100644 --- a/vendor/github.com/vishvananda/netlink/nl/nl_linux_test.go +++ b/vendor/github.com/vishvananda/netlink/nl/nl_linux_test.go @@ -7,6 +7,7 @@ import ( "reflect" "syscall" "testing" + "time" ) type testSerializer interface { @@ -60,3 +61,39 @@ func TestIfInfomsgDeserializeSerialize(t *testing.T) { msg := DeserializeIfInfomsg(orig) testDeserializeSerialize(t, orig, safemsg, msg) } + +func TestIfSocketCloses(t *testing.T) { + nlSock, err := Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_NEIGH) + if err != nil { + t.Fatalf("Error on creating the socket: %v", err) + } + nlSock.SetReceiveTimeout(&syscall.Timeval{Sec: 2, Usec: 0}) + endCh := make(chan error) + go func(sk *NetlinkSocket, endCh chan error) { + endCh <- nil + for { + _, err := sk.Receive() + // Receive returned because of a timeout and the FD == -1 means that the socket got closed + if err == syscall.EAGAIN && nlSock.GetFd() == -1 { + endCh <- err + return + } + } + }(nlSock, endCh) + + // first receive nil + if msg := <-endCh; msg != nil { + t.Fatalf("Expected nil instead got: %v", msg) + } + // this to guarantee that the receive is invoked before the close + time.Sleep(4 * time.Second) + + // Close the socket + nlSock.Close() + + // Expect to have an error + msg := <-endCh + if msg == nil { + t.Fatalf("Expected error instead received nil") + } +}