Skip to content

Commit

Permalink
refact: context propagation (apiclient, cticlient...) (#3477)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc authored Feb 21, 2025
1 parent 8da6a4d commit a3187d6
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 68 deletions.
17 changes: 8 additions & 9 deletions pkg/acquisition/modules/appsec/appsec.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,19 +194,16 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLe

// let's load the associated appsec_config:
if w.config.AppsecConfigPath != "" {
err := appsecCfg.LoadByPath(w.config.AppsecConfigPath)
if err != nil {
if err = appsecCfg.LoadByPath(w.config.AppsecConfigPath); err != nil {
return fmt.Errorf("unable to load appsec_config: %w", err)
}
} else if w.config.AppsecConfig != "" {
err := appsecCfg.Load(w.config.AppsecConfig)
if err != nil {
if err = appsecCfg.Load(w.config.AppsecConfig); err != nil {
return fmt.Errorf("unable to load appsec_config: %w", err)
}
} else if len(w.config.AppsecConfigs) > 0 {
for _, appsecConfig := range w.config.AppsecConfigs {
err := appsecCfg.Load(appsecConfig)
if err != nil {
if err = appsecCfg.Load(appsecConfig); err != nil {
return fmt.Errorf("unable to load appsec_config: %w", err)
}
}
Expand All @@ -233,6 +230,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, metricsLe
if err != nil {
return fmt.Errorf("unable to get authenticated LAPI client: %w", err)
}

w.appsecAllowlistClient = allowlists.NewAppsecAllowlist(w.apiClient, w.logger)

for nbRoutine := range w.config.Routines {
Expand Down Expand Up @@ -371,12 +369,12 @@ func (w *AppsecSource) Dump() interface{} {
return w
}

func (w *AppsecSource) IsAuth(apiKey string) bool {
func (w *AppsecSource) IsAuth(ctx context.Context, apiKey string) bool {
client := &http.Client{
Timeout: 200 * time.Millisecond,
}

req, err := http.NewRequest(http.MethodHead, w.lapiURL, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodHead, w.lapiURL, nil)
if err != nil {
log.Errorf("Error creating request: %s", err)
return false
Expand All @@ -397,6 +395,7 @@ func (w *AppsecSource) IsAuth(apiKey string) bool {

// should this be in the runner ?
func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
w.logger.Debugf("Received request from '%s' on %s", r.RemoteAddr, r.URL.Path)

apiKey := r.Header.Get(appsec.APIKeyHeaderName)
Expand All @@ -413,7 +412,7 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
expiration, exists := w.AuthCache.Get(apiKey)
// if the apiKey is not in cache or has expired, just recheck the auth
if !exists || time.Now().After(expiration) {
if !w.IsAuth(apiKey) {
if !w.IsAuth(ctx, apiKey) {
rw.WriteHeader(http.StatusUnauthorized)
w.logger.Errorf("Unauthorized request from '%s' (real IP = %s)", remoteIP, clientIP)

Expand Down
46 changes: 35 additions & 11 deletions pkg/acquisition/modules/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ basic_auth:
}

func TestStreamingAcquisitionBasicAuth(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -306,7 +307,7 @@ basic_auth:
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
require.NoError(t, err)
req.SetBasicAuth("test", "WrongPassword")

Expand All @@ -321,6 +322,7 @@ basic_auth:
}

func TestStreamingAcquisitionBadHeaders(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -334,7 +336,7 @@ headers:

client := &http.Client{}

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
require.NoError(t, err)

req.Header.Add("Key", "wrong")
Expand All @@ -349,6 +351,7 @@ headers:
}

func TestStreamingAcquisitionMaxBodySize(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -362,7 +365,7 @@ max_body_size: 5`), 0)
time.Sleep(1 * time.Second)

client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest"))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest"))
require.NoError(t, err)

req.Header.Add("Key", "test")
Expand All @@ -378,6 +381,7 @@ max_body_size: 5`), 0)
}

func TestStreamingAcquisitionSuccess(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -388,13 +392,14 @@ headers:
key: test`), 2)

time.Sleep(1 * time.Second)

rawEvt := `{"test": "test"}`

errChan := make(chan error)
go assertEvents(out, []string{rawEvt}, errChan)

client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
require.NoError(t, err)

req.Header.Add("Key", "test")
Expand All @@ -414,6 +419,7 @@ headers:
}

func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -430,9 +436,10 @@ custom_headers:

rawEvt := `{"test": "test"}`
errChan := make(chan error)

go assertEvents(out, []string{rawEvt}, errChan)

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
require.NoError(t, err)

req.Header.Add("Key", "test")
Expand Down Expand Up @@ -463,9 +470,11 @@ func (sr *slowReader) Read(p []byte) (int, error) {
if sr.index >= len(sr.body) {
return 0, io.EOF
}

time.Sleep(sr.delay) // Simulate a delay in reading
n := copy(p, sr.body[sr.index:])
sr.index += n

return n, nil
}

Expand All @@ -492,10 +501,12 @@ func assertEvents(out chan types.Event, expected []string, errChan chan error) {
errChan <- fmt.Errorf(`expected %s, got '%+v'`, expected, evt.Line.Raw)
return
}

if evt.Line.Src != "127.0.0.1" {
errChan <- fmt.Errorf("expected '127.0.0.1', got '%s'", evt.Line.Src)
return
}

if evt.Line.Module != "http" {
errChan <- fmt.Errorf("expected 'http', got '%s'", evt.Line.Module)
return
Expand All @@ -505,6 +516,7 @@ func assertEvents(out chan types.Event, expected []string, errChan chan error) {
}

func TestStreamingAcquisitionTimeout(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
_, _, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -522,7 +534,7 @@ timeout: 1s`), 0)
body: []byte(`{"test": "delayed_payload"}`),
}

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow)
require.NoError(t, err)

req.Header.Add("Key", "test")
Expand Down Expand Up @@ -566,6 +578,7 @@ tls:
}

func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand Down Expand Up @@ -599,9 +612,10 @@ tls:

rawEvt := `{"test": "test"}`
errChan := make(chan error)

go assertEvents(out, []string{rawEvt}, errChan)

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
require.NoError(t, err)

req.Header.Add("Key", "test")
Expand All @@ -622,6 +636,7 @@ tls:
}

func TestStreamingAcquisitionMTLS(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand Down Expand Up @@ -658,9 +673,10 @@ tls:

rawEvt := `{"test": "test"}`
errChan := make(chan error)

go assertEvents(out, []string{rawEvt}, errChan)

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
require.NoError(t, err)

resp, err := client.Do(req)
Expand All @@ -680,6 +696,7 @@ tls:
}

func TestStreamingAcquisitionGzipData(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -693,6 +710,7 @@ headers:

rawEvt := `{"test": "test"}`
errChan := make(chan error)

go assertEvents(out, []string{rawEvt, rawEvt}, errChan)

var b strings.Builder
Expand All @@ -709,7 +727,7 @@ headers:

// send gzipped compressed data
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String()))
require.NoError(t, err)

req.Header.Add("Key", "test")
Expand All @@ -733,6 +751,7 @@ headers:
}

func TestStreamingAcquisitionNDJson(t *testing.T) {
ctx := t.Context()
h := &HTTPSource{}
out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(`
source: http
Expand All @@ -743,13 +762,14 @@ headers:
key: test`), 2)

time.Sleep(1 * time.Second)
rawEvt := `{"test": "test"}`

rawEvt := `{"test": "test"}`
errChan := make(chan error)

go assertEvents(out, []string{rawEvt, rawEvt}, errChan)

client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt)))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt)))

require.NoError(t, err)

Expand All @@ -776,10 +796,13 @@ func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.
require.NoError(t, err)

isExist := false

for _, metricFamily := range promMetrics {
if metricFamily.GetName() == "cs_httpsource_hits_total" {
isExist = true

assert.Len(t, metricFamily.GetMetric(), 1)

for _, metric := range metricFamily.GetMetric() {
assert.InDelta(t, float64(expected), metric.GetCounter().GetValue(), 0.000001)
labels := metric.GetLabel()
Expand All @@ -791,6 +814,7 @@ func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.
}
}
}

if !isExist && expected > 0 {
t.Fatalf("expected metric cs_httpsource_hits_total not found")
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/apiclient/alerts_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type AlertsDeleteOpts struct {
func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) {
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)

req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
req, err := s.client.NewRequestWithContext(ctx, http.MethodPost, u, &alerts)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.
URI = fmt.Sprintf("%s?%s", URI, params.Encode())
}

req, err := s.client.NewRequest(http.MethodGet, URI, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, URI, nil)
if err != nil {
return nil, nil, fmt.Errorf("building request: %w", err)
}
Expand All @@ -102,7 +102,7 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod

u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode())

req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, nil, err
}
Expand All @@ -120,7 +120,7 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) {
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID)

req, err := s.client.NewRequest(http.MethodDelete, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
if err != nil {
return nil, nil, err
}
Expand All @@ -138,7 +138,7 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.
func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) {
u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID)

req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/apiclient/allowlists_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (s *AllowlistsService) List(ctx context.Context, opts AllowlistListOpts) (*

u += "?" + params.Encode()

req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -58,7 +58,7 @@ func (s *AllowlistsService) Get(ctx context.Context, name string, opts Allowlist

log.Debugf("GET %s", u)

req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
Expand All @@ -76,7 +76,7 @@ func (s *AllowlistsService) Get(ctx context.Context, name string, opts Allowlist
func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string) (bool, *Response, error) {
u := s.client.URLPrefix + "/allowlists/check/" + value

req, err := s.client.NewRequest(http.MethodHead, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodHead, u, nil)
if err != nil {
return false, nil, err
}
Expand All @@ -94,7 +94,7 @@ func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string
func (s *AllowlistsService) CheckIfAllowlistedWithReason(ctx context.Context, value string) (*models.CheckAllowlistResponse, *Response, error) {
u := s.client.URLPrefix + "/allowlists/check/" + value

req, err := s.client.NewRequest(http.MethodGet, u, nil)
req, err := s.client.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
Expand Down
Loading

0 comments on commit a3187d6

Please sign in to comment.