diff --git a/cmd/mock-driver/main.go b/cmd/mock-driver/main.go index 0b8a477e..12b64717 100644 --- a/cmd/mock-driver/main.go +++ b/cmd/mock-driver/main.go @@ -40,24 +40,11 @@ func main() { flag.Parse() endpoint := os.Getenv("CSI_ENDPOINT") - if len(endpoint) == 0 { - fmt.Println("CSI_ENDPOINT must be defined and must be a path") - os.Exit(1) - } - if strings.Contains(endpoint, ":") { - fmt.Println("CSI_ENDPOINT must be a unix path") - os.Exit(1) - } - controllerEndpoint := os.Getenv("CSI_CONTROLLER_ENDPOINT") if len(controllerEndpoint) == 0 { // If empty, set to the common endpoint. controllerEndpoint = endpoint } - if strings.Contains(controllerEndpoint, ":") { - fmt.Println("CSI_CONTROLLER_ENDPOINT must be a unix path") - os.Exit(1) - } // Create mock driver s := service.New(config) @@ -77,16 +64,14 @@ func main() { } // Listen - os.Remove(endpoint) - os.Remove(controllerEndpoint) - l, err := net.Listen("unix", endpoint) + l, cleanup, err := listen(endpoint) if err != nil { fmt.Printf("Error: Unable to listen on %s socket: %v\n", endpoint, err) os.Exit(1) } - defer os.Remove(endpoint) + defer cleanup() // Start server if err := d.Start(l); err != nil { @@ -129,15 +114,14 @@ func main() { } // Listen controller. - os.Remove(controllerEndpoint) - l, err := net.Listen("unix", controllerEndpoint) + l, cleanupController, err := listen(controllerEndpoint) if err != nil { fmt.Printf("Error: Unable to listen on %s socket: %v\n", controllerEndpoint, err) os.Exit(1) } - defer os.Remove(controllerEndpoint) + defer cleanupController() // Start controller server. if err = dc.Start(l); err != nil { @@ -148,15 +132,14 @@ func main() { fmt.Println("mock controller driver started") // Listen node. - os.Remove(endpoint) - l, err = net.Listen("unix", endpoint) + l, cleanupNode, err := listen(endpoint) if err != nil { fmt.Printf("Error: Unable to listen on %s socket: %v\n", endpoint, err) os.Exit(1) } - defer os.Remove(endpoint) + defer cleanupNode() // Start node server. if err = dn.Start(l); err != nil { @@ -182,3 +165,36 @@ func main() { fmt.Println("mock drivers stopped") } } + +func parseEndpoint(ep string) (string, string, error) { + if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") { + s := strings.SplitN(ep, "://", 2) + if s[1] != "" { + return s[0], s[1], nil + } + return "", "", fmt.Errorf("Invalid endpoint: %v", ep) + } + // Assume everything else is a file path for a Unix Domain Socket. + return "unix", ep, nil +} + +func listen(endpoint string) (net.Listener, func(), error) { + proto, addr, err := parseEndpoint(endpoint) + if err != nil { + return nil, nil, err + } + + cleanup := func() {} + if proto == "unix" { + addr = "/" + addr + if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow + return nil, nil, fmt.Errorf("%s: %q", addr, err) + } + cleanup = func() { + os.Remove(addr) + } + } + + l, err := net.Listen(proto, addr) + return l, cleanup, err +} diff --git a/hack/e2e.sh b/hack/e2e.sh index e3a8c6cf..8cb59e9c 100755 --- a/hack/e2e.sh +++ b/hack/e2e.sh @@ -4,6 +4,10 @@ TESTARGS=$@ UDS="/tmp/e2e-csi-sanity.sock" UDS_NODE="/tmp/e2e-csi-sanity-node.sock" UDS_CONTROLLER="/tmp/e2e-csi-sanity-ctrl.sock" +# Protocol specified as for net.Listen... +TCP_SERVER="tcp://localhost:7654" +# ... and slightly differently for gRPC. +TCP_CLIENT="dns:///localhost:7654" CSI_ENDPOINTS="$CSI_ENDPOINTS ${UDS}" CSI_MOCK_VERSION="master" @@ -108,6 +112,7 @@ cd cmd/csi-sanity make clean install || exit 1 cd ../.. +runTest "${TCP_SERVER}" "${TCP_CLIENT}" && runTest "${UDS}" "${UDS}" && runTestWithCreds "${UDS}" "${UDS}" && runTestAPI "${UDS}" &&