diff --git a/cmd/root.go b/cmd/root.go index 8ba1e41..a1765a3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -47,11 +47,13 @@ Please do not use that tool with devil needs. } }() - if opts.FloodDurationSeconds != -1 { - <-time.After(time.Duration(opts.FloodDurationSeconds) * time.Second) - shouldStop <- true - close(shouldStop) - } + go func() { + if opts.FloodDurationSeconds != -1 { + <-time.After(time.Duration(opts.FloodDurationSeconds) * time.Second) + shouldStop <- true + close(shouldStop) + } + }() for { select { diff --git a/internal/raw/raw.go b/internal/raw/raw.go index f644d10..ad93db1 100644 --- a/internal/raw/raw.go +++ b/internal/raw/raw.go @@ -17,7 +17,7 @@ func init() { } // StartFlooding does the heavy lifting, starts the flood -func StartFlooding(shouldStop chan bool, destinationHost string, destinationPort, payloadLength int, floodType string) error { +func StartFlooding(stopChan chan bool, destinationHost string, destinationPort, payloadLength int, floodType string) error { var ( ipHeader *ipv4.Header packetConn net.PacketConn @@ -41,7 +41,7 @@ func StartFlooding(shouldStop chan bool, destinationHost string, destinationPort for { select { - case <-shouldStop: + case <-stopChan: return nil default: tcpPacket := buildTcpPacket(srcPorts[rand.Intn(len(srcPorts))], destinationPort, floodType) diff --git a/internal/raw/raw_test.go b/internal/raw/raw_test.go index 70ee422..5bb3459 100644 --- a/internal/raw/raw_test.go +++ b/internal/raw/raw_test.go @@ -18,45 +18,35 @@ func TestStartFlooding(t *testing.T) { srcMacAddr, dstMacAddr []byte }{ {"100byte_syn", "syn", 10, srcPorts[rand.Intn(len(srcPorts))], - 443, 100, srcIps[rand.Intn(len(srcIps))], "213.238.175.187", + 443, 1000, srcIps[rand.Intn(len(srcIps))], "213.238.175.187", macAddrs[rand.Intn(len(macAddrs))], macAddrs[rand.Intn(len(macAddrs))]}, { "100byte_ack", "ack", 10, srcPorts[rand.Intn(len(srcPorts))], - 443, 100, srcIps[rand.Intn(len(srcIps))], "213.238.175.187", + 443, 1000, srcIps[rand.Intn(len(srcIps))], "213.238.175.187", macAddrs[rand.Intn(len(macAddrs))], macAddrs[rand.Intn(len(macAddrs))], }, { "100byte_synack", "synAck", 10, srcPorts[rand.Intn(len(srcPorts))], - 443, 100, srcIps[rand.Intn(len(srcIps))], "213.238.175.187", + 443, 1000, srcIps[rand.Intn(len(srcIps))], "213.238.175.187", macAddrs[rand.Intn(len(macAddrs))], macAddrs[rand.Intn(len(macAddrs))], }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - shouldStop := make(chan bool) + stopChan := make(chan bool) t.Logf("starting flood, caseName=%s, floodType=%s, floodMilliSeconds=%d\n", tc.name, tc.floodType, tc.floodMilliSeconds) - go func() { - err := StartFlooding(shouldStop, tc.dstIp, tc.dstPort, tc.payloadLength, tc.floodType) + go func(stopChan chan bool, dstIp string, dstPort int, payloadLength int, floodType string) { + err := StartFlooding(stopChan, dstIp, dstPort, payloadLength, floodType) if err != nil { t.Errorf("an error occured on flooding process: %s\n", err.Error()) return } - }() + }(stopChan, tc.dstIp, tc.dstPort, tc.payloadLength, tc.floodType) <-time.After(time.Duration(tc.floodMilliSeconds) * time.Millisecond) - shouldStop <- true - close(shouldStop) - - for { - select { - case <-shouldStop: - t.Logf("\nshouldStop channel received a signal, stopping\n") - return - default: - continue - } - } + stopChan <- true + t.Logf("\nshouldStop channel received a signal, stopping\n") }) } }