From 862c0915337427ce23a9febfb3dfab422e3b623d Mon Sep 17 00:00:00 2001
From: Joel Takvorian <jtakvori@redhat.com>
Date: Wed, 22 Nov 2023 13:08:09 +0100
Subject: [PATCH] NETOBSERV-1102: fine-tuning http server settings

Similar to https://github.com/netobserv/network-observability-console-plugin/pull/428
---
 cmd/flowlogs-pipeline/main.go      |  3 +-
 pkg/operational/health.go          | 35 ++++++-------
 pkg/pipeline/health_test.go        |  2 +-
 pkg/prometheus/prom_server.go      |  6 ++-
 pkg/prometheus/prom_server_test.go | 81 ++++++++++++++++++++++++++++++
 pkg/server/common.go               | 46 +++++++++++++++++
 6 files changed, 148 insertions(+), 25 deletions(-)
 create mode 100644 pkg/prometheus/prom_server_test.go
 create mode 100644 pkg/server/common.go

diff --git a/cmd/flowlogs-pipeline/main.go b/cmd/flowlogs-pipeline/main.go
index a2fa5246e..f24e13589 100644
--- a/cmd/flowlogs-pipeline/main.go
+++ b/cmd/flowlogs-pipeline/main.go
@@ -199,7 +199,7 @@ func run() {
 	}
 
 	// Start health report server
-	operational.NewHealthServer(&opts, mainPipeline.IsAlive, mainPipeline.IsReady)
+	healthServer := operational.NewHealthServer(&opts, mainPipeline.IsAlive, mainPipeline.IsReady)
 
 	// Starts the flows pipeline
 	mainPipeline.Run()
@@ -207,6 +207,7 @@ func run() {
 	if promServer != nil {
 		_ = promServer.Shutdown(context.Background())
 	}
+	_ = healthServer.Shutdown(context.Background())
 
 	// Give all threads a chance to exit and then exit the process
 	time.Sleep(time.Second)
diff --git a/pkg/operational/health.go b/pkg/operational/health.go
index a252ebc46..c4510a695 100644
--- a/pkg/operational/health.go
+++ b/pkg/operational/health.go
@@ -24,35 +24,28 @@ import (
 
 	"github.com/heptiolabs/healthcheck"
 	"github.com/netobserv/flowlogs-pipeline/pkg/config"
+	"github.com/netobserv/flowlogs-pipeline/pkg/server"
 	log "github.com/sirupsen/logrus"
 )
 
-type Server struct {
-	handler healthcheck.Handler
-	Address string
-}
-
-func (hs *Server) Serve() {
-	for {
-		err := http.ListenAndServe(hs.Address, hs.handler)
-		log.Errorf("http.ListenAndServe error %v", err)
-		time.Sleep(60 * time.Second)
-	}
-}
-
-func NewHealthServer(opts *config.Options, isAlive healthcheck.Check, isReady healthcheck.Check) *Server {
-
+func NewHealthServer(opts *config.Options, isAlive healthcheck.Check, isReady healthcheck.Check) *http.Server {
 	handler := healthcheck.NewHandler()
 	address := net.JoinHostPort(opts.Health.Address, opts.Health.Port)
 	handler.AddLivenessCheck("PipelineCheck", isAlive)
 	handler.AddReadinessCheck("PipelineCheck", isReady)
 
-	server := &Server{
-		handler: handler,
-		Address: address,
-	}
-
-	go server.Serve()
+	server := server.Default(&http.Server{
+		Handler: handler,
+		Addr:    address,
+	})
+
+	go func() {
+		for {
+			err := server.ListenAndServe()
+			log.Errorf("http.ListenAndServe error %v", err)
+			time.Sleep(60 * time.Second)
+		}
+	}()
 
 	return server
 }
diff --git a/pkg/pipeline/health_test.go b/pkg/pipeline/health_test.go
index 54bd6a649..af16c42ba 100644
--- a/pkg/pipeline/health_test.go
+++ b/pkg/pipeline/health_test.go
@@ -58,7 +58,7 @@ func TestNewHealthServer(t *testing.T) {
 			expectedAddr := fmt.Sprintf("%s:%s", opts.Health.Address, opts.Health.Port)
 			server := operational.NewHealthServer(&opts, tt.args.pipeline.IsAlive, tt.args.pipeline.IsReady)
 			require.NotNil(t, server)
-			require.Equal(t, expectedAddr, server.Address)
+			require.Equal(t, expectedAddr, server.Addr)
 
 			client := &http.Client{}
 
diff --git a/pkg/prometheus/prom_server.go b/pkg/prometheus/prom_server.go
index d40000bc0..19a226583 100644
--- a/pkg/prometheus/prom_server.go
+++ b/pkg/prometheus/prom_server.go
@@ -24,6 +24,7 @@ import (
 
 	"github.com/netobserv/flowlogs-pipeline/pkg/api"
 	"github.com/netobserv/flowlogs-pipeline/pkg/config"
+	"github.com/netobserv/flowlogs-pipeline/pkg/server"
 	prom "github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/sirupsen/logrus"
@@ -63,7 +64,7 @@ func StartServerAsync(conn *api.PromConnectionInfo, registry *prom.Registry) *ht
 	addr := fmt.Sprintf("%s:%v", conn.Address, port)
 	plog.Infof("StartServerAsync: addr = %s", addr)
 
-	httpServer := http.Server{
+	httpServer := &http.Server{
 		Addr: addr,
 		// TLS clients must use TLS 1.2 or higher
 		TLSConfig: &tls.Config{
@@ -79,6 +80,7 @@ func StartServerAsync(conn *api.PromConnectionInfo, registry *prom.Registry) *ht
 		mux.Handle("/metrics", promhttp.HandlerFor(registry, promhttp.HandlerOpts{}))
 	}
 	httpServer.Handler = mux
+	httpServer = server.Default(httpServer)
 
 	go func() {
 		var err error
@@ -92,5 +94,5 @@ func StartServerAsync(conn *api.PromConnectionInfo, registry *prom.Registry) *ht
 		}
 	}()
 
-	return &httpServer
+	return httpServer
 }
diff --git a/pkg/prometheus/prom_server_test.go b/pkg/prometheus/prom_server_test.go
new file mode 100644
index 000000000..e413e7f85
--- /dev/null
+++ b/pkg/prometheus/prom_server_test.go
@@ -0,0 +1,81 @@
+package prometheus
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/netobserv/flowlogs-pipeline/pkg/config"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestStartPromServer(t *testing.T) {
+	srv := InitializePrometheus(&config.MetricsSettings{})
+
+	serverURL := "http://0.0.0.0:9090"
+	t.Logf("Started test http server: %v", serverURL)
+
+	httpClient := &http.Client{}
+
+	// wait for our test http server to come up
+	checkHTTPReady(httpClient, serverURL)
+
+	r, err := http.NewRequest("GET", serverURL+"/metrics", nil)
+	require.NoError(t, err)
+
+	resp, err := httpClient.Do(r)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	bodyBytes, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+
+	bodyString := string(bodyBytes)
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+	require.Contains(t, bodyString, "go_gc_duration_seconds")
+
+	_ = srv.Shutdown(context.Background())
+}
+
+func TestStartPromServer_HeadersLimit(t *testing.T) {
+	srv := InitializePrometheus(&config.MetricsSettings{})
+
+	serverURL := "http://0.0.0.0:9090"
+	t.Logf("Started test http server: %v", serverURL)
+
+	httpClient := &http.Client{}
+
+	// wait for our test http server to come up
+	checkHTTPReady(httpClient, serverURL)
+
+	r, err := http.NewRequest("GET", serverURL+"/metrics", nil)
+	require.NoError(t, err)
+
+	// Set many headers
+	oneKBString := strings.Repeat(".", 1024)
+	for i := 0; i < 1025; i++ {
+		r.Header.Set(fmt.Sprintf("test-header-%d", i), oneKBString)
+	}
+
+	resp, err := httpClient.Do(r)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+	assert.Equal(t, http.StatusRequestHeaderFieldsTooLarge, resp.StatusCode)
+
+	_ = srv.Shutdown(context.Background())
+}
+
+func checkHTTPReady(httpClient *http.Client, url string) {
+	for i := 0; i < 60; i++ {
+		if r, err := httpClient.Get(url); err == nil {
+			r.Body.Close()
+			break
+		}
+		time.Sleep(time.Second)
+	}
+}
diff --git a/pkg/server/common.go b/pkg/server/common.go
new file mode 100644
index 000000000..92fdb7fb5
--- /dev/null
+++ b/pkg/server/common.go
@@ -0,0 +1,46 @@
+package server
+
+import (
+	"crypto/tls"
+	"net/http"
+	"time"
+
+	"github.com/sirupsen/logrus"
+)
+
+var slog = logrus.WithField("module", "server")
+
+func Default(srv *http.Server) *http.Server {
+	// defaults taken from https://bruinsslot.jp/post/go-secure-webserver/ can be overriden by caller
+	if srv.Handler != nil {
+		// No more than 2MB body
+		srv.Handler = http.MaxBytesHandler(srv.Handler, 2<<20)
+	} else {
+		slog.Warnf("Handler not yet set on server while securing defaults. Make sure a MaxByte middleware is used.")
+	}
+	if srv.ReadTimeout == 0 {
+		srv.ReadTimeout = 10 * time.Second
+	}
+	if srv.ReadHeaderTimeout == 0 {
+		srv.ReadHeaderTimeout = 5 * time.Second
+	}
+	if srv.WriteTimeout == 0 {
+		srv.WriteTimeout = 10 * time.Second
+	}
+	if srv.IdleTimeout == 0 {
+		srv.IdleTimeout = 120 * time.Second
+	}
+	if srv.MaxHeaderBytes == 0 {
+		srv.MaxHeaderBytes = 1 << 20 // 1MB
+	}
+	if srv.TLSConfig == nil {
+		srv.TLSConfig = &tls.Config{}
+	}
+	if srv.TLSConfig.MinVersion == 0 {
+		srv.TLSConfig.MinVersion = tls.VersionTLS13
+	}
+	// Disable http/2
+	srv.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0)
+
+	return srv
+}