From 67188aeb944cc5975529356bff77a0b270df1d30 Mon Sep 17 00:00:00 2001
From: Craig Johnston <cj@imti.co>
Date: Sat, 17 Oct 2020 16:34:54 -0700
Subject: [PATCH] host management cleanup

---
 pkg/fwdport/fwdport.go | 61 ++++++++++++++++++++----------------------
 1 file changed, 29 insertions(+), 32 deletions(-)

diff --git a/pkg/fwdport/fwdport.go b/pkg/fwdport/fwdport.go
index cbc0229a..9e1eaf5c 100644
--- a/pkg/fwdport/fwdport.go
+++ b/pkg/fwdport/fwdport.go
@@ -59,6 +59,7 @@ type PortForwardOpts struct {
 	Remote         bool
 	Domain         string
 	HostsParams    *HostsParams
+	Hosts          []string
 	ManualStopChan chan struct{} // Send a signal on this to stop the portforwarding
 	DoneChan       chan struct{} // Listen on this channel for when the shutdown is completed.
 }
@@ -107,7 +108,7 @@ func (pfo *PortForwardOpts) PortForward() error {
 
 	}()
 
-	// Waiting until the pod is runnning
+	// Waiting until the pod is running
 	pod, err := pfo.WaitUntilPodRunning(downstreamStopChannel)
 	if err != nil {
 		pfo.Stop()
@@ -165,7 +166,19 @@ func (pfo *PortForwardOpts) BuildTheHostsParams() {
 	pfo.HostsParams.svcServiceName = svcServiceName
 }
 
-// this method to add hosts obj in /etc/hosts
+// AddHost
+func (pfo *PortForwardOpts) addHost(host string) {
+	// add to list of hostnames for this port-forward
+	pfo.Hosts = append(pfo.Hosts, host)
+
+	// remove host if it already exists in /etc/hosts
+	pfo.Hostfile.Hosts.RemoveHost(host)
+
+	// add host to /etc/hosts
+	pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), host)
+}
+
+// AddHosts adds hostname entries to /etc/hosts
 func (pfo *PortForwardOpts) AddHosts() {
 
 	pfo.Hostfile.Lock()
@@ -173,36 +186,30 @@ func (pfo *PortForwardOpts) AddHosts() {
 
 		pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.fullServiceName)
 		pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.svcServiceName)
+
 		if pfo.Domain != "" {
-			pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.Service+"."+pfo.Domain)
+			pfo.addHost(pfo.Service + "." + pfo.Domain)
 		}
-		pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.Service)
+		pfo.addHost(pfo.Service)
 
 	} else {
 
 		if pfo.ShortName {
 			if pfo.Domain != "" {
-				pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.localServiceName + "." + pfo.Domain)
-				pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.HostsParams.localServiceName+"."+pfo.Domain)
+				pfo.addHost(pfo.HostsParams.localServiceName + "." + pfo.Domain)
 			}
-			pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.localServiceName)
-			pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.HostsParams.localServiceName)
+			pfo.addHost(pfo.HostsParams.localServiceName)
 		}
 
-		pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.fullServiceName)
-		pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.HostsParams.fullServiceName)
-
-		pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.svcServiceName)
-		pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.HostsParams.svcServiceName)
+		pfo.addHost(pfo.HostsParams.fullServiceName)
+		pfo.addHost(pfo.HostsParams.svcServiceName)
 
 		if pfo.Domain != "" {
-			pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.nsServiceName + "." + pfo.Domain)
-			pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.HostsParams.nsServiceName+"."+pfo.Domain)
+			pfo.addHost(pfo.HostsParams.nsServiceName + "." + pfo.Domain)
 		}
-		pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.nsServiceName)
-		pfo.Hostfile.Hosts.AddHost(pfo.LocalIp.String(), pfo.HostsParams.nsServiceName)
-
+		pfo.addHost(pfo.HostsParams.nsServiceName)
 	}
+
 	err := pfo.Hostfile.Hosts.Save()
 	if err != nil {
 		log.Error("Error saving hosts file", err)
@@ -212,6 +219,7 @@ func (pfo *PortForwardOpts) AddHosts() {
 
 // this method to remove hosts obj in /etc/hosts
 func (pfo *PortForwardOpts) removeHosts() {
+
 	// we should lock the pfo.Hostfile here
 	// because sometimes other goroutine write the *txeh.Hosts
 	pfo.Hostfile.Lock()
@@ -223,21 +231,10 @@ func (pfo *PortForwardOpts) removeHosts() {
 		return
 	}
 
-	if !pfo.Remote {
-		if pfo.Domain != "" {
-			// fmt.Printf("removeHost: %s\r\n", (pfo.HostsParams.localServiceName + "." + pfo.Domain))
-			// fmt.Printf("removeHost: %s\r\n", (pfo.HostsParams.nsServiceName + "." + pfo.Domain))
-			pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.localServiceName + "." + pfo.Domain)
-			pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.nsServiceName + "." + pfo.Domain)
-		}
-		// fmt.Printf("removeHost: %s\r\n", pfo.HostsParams.localServiceName)
-		// fmt.Printf("removeHost: %s\r\n", pfo.HostsParams.nsServiceName)
-		pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.localServiceName)
-		pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.nsServiceName)
+	// remove all hosts
+	for _, host := range pfo.Hosts {
+		pfo.Hostfile.Hosts.RemoveHost(host)
 	}
-	// fmt.Printf("removeHost: %s\r\n", pfo.HostsParams.fullServiceName)
-	pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.fullServiceName)
-	pfo.Hostfile.Hosts.RemoveHost(pfo.HostsParams.svcServiceName)
 
 	// fmt.Printf("Delete Host And Save !\r\n")
 	err = pfo.Hostfile.Hosts.Save()