diff --git a/server.go b/server.go index 7c23a31..64d4e88 100644 --- a/server.go +++ b/server.go @@ -32,6 +32,7 @@ type Server struct { LMTP bool Domain string + MaxConnections int MaxRecipients int MaxMessageBytes int64 MaxLineLength int @@ -127,17 +128,7 @@ func (s *Server) Serve(l net.Listener) error { } func (s *Server) handleConn(c *Conn) error { - s.locker.Lock() - s.conns[c] = struct{}{} - s.locker.Unlock() - - defer func() { - c.Close() - - s.locker.Lock() - delete(s.conns, c) - s.locker.Unlock() - }() + defer c.Close() if tlsConn, ok := c.conn.(*tls.Conn); ok { if d := s.ReadTimeout; d != 0 { @@ -151,6 +142,29 @@ func (s *Server) handleConn(c *Conn) error { } } + // register connection + maxConnsExceeded := false + s.locker.Lock() + if s.MaxConnections > 0 && len(s.conns) >= s.MaxConnections { + maxConnsExceeded = true + } else { + s.conns[c] = struct{}{} + } + s.locker.Unlock() + + // limit connections + if maxConnsExceeded { + c.writeResponse(421, EnhancedCode{4, 4, 5}, "Too busy. Try again later.") + return nil + } + + // unregister connection + defer func() { + s.locker.Lock() + delete(s.conns, c) + s.locker.Unlock() + }() + c.greet() for { diff --git a/server_test.go b/server_test.go index 206d336..f5cd3ca 100644 --- a/server_test.go +++ b/server_test.go @@ -1514,3 +1514,49 @@ func TestServerDSNwithSMTPUTF8(t *testing.T) { t.Fatal("Invalid ORCPT address:", val) } } + +func TestServer_MaxConnections(t *testing.T) { + cases := []struct { + name string + maxConnections int + expected string + }{ + // 0 = unlimited; all connections should be accepted + {name: "MaxConnections set to 0", maxConnections: 0, expected: "220 localhost ESMTP Service Ready"}, + // 1 = only one connection is allowed; the second connection should be rejected + {name: "MaxConnections set to 1", maxConnections: 1, expected: "421 4.4.5 Too busy. Try again later."}, + // 2 = two connections are allowed; the second connection should be accepted + {name: "MaxConnections set to 2", maxConnections: 2, expected: "220 localhost ESMTP Service Ready"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // create server with limited allowed connections + _, s, c, scanner1 := testServer(t, func(s *smtp.Server) { + s.MaxConnections = tc.maxConnections + }) + defer s.Close() + + // there is already be one connection registered + // and we can read the greeting from it (see testServerGreeted()) + scanner1.Scan() + if scanner1.Text() != "220 localhost ESMTP Service Ready" { + t.Fatal("Invalid first greeting:", scanner1.Text()) + } + + // now we create a second connection + c2, err := net.Dial("tcp", c.RemoteAddr().String()) + if err != nil { + t.Fatal("Error creating second connection:", err) + } + + // we should get an appropriate greeting now + scanner2 := bufio.NewScanner(c2) + scanner2.Scan() + if scanner2.Text() != tc.expected { + t.Fatal("Invalid second greeting:", scanner2.Text()) + } + }) + } + +}