From 001b0962289547cf2d9f45dc85570962faf2fc68 Mon Sep 17 00:00:00 2001
From: Jaime Soriano Pastor <jaime.soriano@elastic.co>
Date: Wed, 26 Jun 2019 18:46:08 +0200
Subject: [PATCH] Add new metricset interface with context (#11981)

Add new reporting and fetch metricset interfaces with a context
that is created from the done channel of the metricset wrappers.
This way the framework provides a context to metricsets that
can be used on fetch time when doing requests to services.

Two metricsets are migrated to these new interfaces.
---
 CHANGELOG-developer.next.asciidoc             |   1 +
 metricbeat/mb/builders.go                     |  13 +-
 metricbeat/mb/mb.go                           |  17 +++
 metricbeat/mb/module/wrapper.go               |  39 +++++-
 metricbeat/mb/testing/data_generator.go       |  21 +++
 metricbeat/mb/testing/modules.go              | 121 +++++++++++-------
 .../module/docker/container/container.go      |   4 +-
 .../container/container_integration_test.go   |   4 +-
 metricbeat/module/docker/event/event.go       |   8 +-
 .../docker/event/event_integration_test.go    |   6 +-
 10 files changed, 169 insertions(+), 65 deletions(-)

diff --git a/CHANGELOG-developer.next.asciidoc b/CHANGELOG-developer.next.asciidoc
index 7b7556e54800..ad2b96476686 100644
--- a/CHANGELOG-developer.next.asciidoc
+++ b/CHANGELOG-developer.next.asciidoc
@@ -41,4 +41,5 @@ The list below covers the major changes between 7.0.0-rc2 and master only.
 - Add new option `IgnoreAllErrors` to `libbeat.common.schema` for skipping fields that failed while converting. {pull}12089[12089]
 - Deprecate setup cmds for `template` and `ilm-policy`. Add new setup cmd for `index-management`. {pull}12132[12132]
 - Use the go-lookslike library for testing in heartbeat. Eventually the mapval package will be replaced with it. {pull}12540[12540]
+- New ReporterV2 interfaces that can receive a context on `Fetch(ctx, reporter)`, or `Run(ctx, reporter)`. {pull}11981[11981]
 - Generate configuration from `mage` for all Beats. {pull}12618[12618]
diff --git a/metricbeat/mb/builders.go b/metricbeat/mb/builders.go
index deb2e43ec63f..04b32f5638aa 100644
--- a/metricbeat/mb/builders.go
+++ b/metricbeat/mb/builders.go
@@ -241,15 +241,24 @@ func mustImplementFetcher(ms MetricSet) error {
 		ifcs = append(ifcs, "ReportingMetricSetV2Error")
 	}
 
+	if _, ok := ms.(ReportingMetricSetV2WithContext); ok {
+		ifcs = append(ifcs, "ReportingMetricSetV2WithContext")
+	}
+
 	if _, ok := ms.(PushMetricSetV2); ok {
 		ifcs = append(ifcs, "PushMetricSetV2")
 	}
+
+	if _, ok := ms.(PushMetricSetV2WithContext); ok {
+		ifcs = append(ifcs, "PushMetricSetV2WithContext")
+	}
+
 	switch len(ifcs) {
 	case 0:
 		return fmt.Errorf("MetricSet '%s/%s' does not implement an event "+
 			"producing interface (EventFetcher, EventsFetcher, "+
-			"ReportingMetricSet, ReportingMetricSetV2, ReportingMetricSetV2Error, PushMetricSet, or "+
-			"PushMetricSetV2)",
+			"ReportingMetricSet, ReportingMetricSetV2, ReportingMetricSetV2Error, ReportingMetricSetV2WithContext"+
+			"PushMetricSet, PushMetricSetV2, or PushMetricSetV2WithContext)",
 			ms.Module().Name(), ms.Name())
 	case 1:
 		return nil
diff --git a/metricbeat/mb/mb.go b/metricbeat/mb/mb.go
index e38655a6907f..d55b8971d4a8 100644
--- a/metricbeat/mb/mb.go
+++ b/metricbeat/mb/mb.go
@@ -22,6 +22,7 @@ to implement Modules and their associated MetricSets.
 package mb
 
 import (
+	"context"
 	"fmt"
 	"net/url"
 	"time"
@@ -208,6 +209,13 @@ type ReportingMetricSetV2Error interface {
 	Fetch(r ReporterV2) error
 }
 
+// ReportingMetricSetV2WithContext is a MetricSet that reports events or errors through the
+// ReporterV2 interface. Fetch is called periodically to collect events.
+type ReportingMetricSetV2WithContext interface {
+	MetricSet
+	Fetch(ctx context.Context, r ReporterV2) error
+}
+
 // PushMetricSetV2 is a MetricSet that pushes events (rather than pulling them
 // periodically via a Fetch callback). Run is invoked to start the event
 // subscription and it should block until the MetricSet is ready to stop or
@@ -217,6 +225,15 @@ type PushMetricSetV2 interface {
 	Run(r PushReporterV2)
 }
 
+// PushMetricSetV2WithContext is a MetricSet that pushes events (rather than pulling them
+// periodically via a Fetch callback). Run is invoked to start the event
+// subscription and it should block until the MetricSet is ready to stop or
+// the context is closed.
+type PushMetricSetV2WithContext interface {
+	MetricSet
+	Run(ctx context.Context, r ReporterV2)
+}
+
 // HostData contains values parsed from the 'host' configuration. Other
 // configuration data like protocols, usernames, and passwords may also be
 // used to construct this HostData data.
diff --git a/metricbeat/mb/module/wrapper.go b/metricbeat/mb/module/wrapper.go
index 245876cd655b..061e7987b811 100644
--- a/metricbeat/mb/module/wrapper.go
+++ b/metricbeat/mb/module/wrapper.go
@@ -18,6 +18,7 @@
 package module
 
 import (
+	"context"
 	"fmt"
 	"math/rand"
 	"sync"
@@ -191,9 +192,11 @@ func (msw *metricSetWrapper) run(done <-chan struct{}, out chan<- beat.Event) {
 		ms.Run(reporter.V1())
 	case mb.PushMetricSetV2:
 		ms.Run(reporter.V2())
+	case mb.PushMetricSetV2WithContext:
+		ms.Run(&channelContext{done}, reporter.V2())
 	case mb.EventFetcher, mb.EventsFetcher,
-		mb.ReportingMetricSet, mb.ReportingMetricSetV2, mb.ReportingMetricSetV2Error:
-		msw.startPeriodicFetching(reporter)
+		mb.ReportingMetricSet, mb.ReportingMetricSetV2, mb.ReportingMetricSetV2Error, mb.ReportingMetricSetV2WithContext:
+		msw.startPeriodicFetching(&channelContext{done}, reporter)
 	default:
 		// Earlier startup stages prevent this from happening.
 		logp.Err("MetricSet '%s/%s' does not implement an event producing interface",
@@ -204,9 +207,9 @@ func (msw *metricSetWrapper) run(done <-chan struct{}, out chan<- beat.Event) {
 // startPeriodicFetching performs an immediate fetch for the MetricSet then it
 // begins a continuous timer scheduled loop to fetch data. To stop the loop the
 // done channel should be closed.
-func (msw *metricSetWrapper) startPeriodicFetching(reporter reporter) {
+func (msw *metricSetWrapper) startPeriodicFetching(ctx context.Context, reporter reporter) {
 	// Fetch immediately.
-	msw.fetch(reporter)
+	msw.fetch(ctx, reporter)
 
 	// Start timer for future fetches.
 	t := time.NewTicker(msw.Module().Config().Period)
@@ -216,7 +219,7 @@ func (msw *metricSetWrapper) startPeriodicFetching(reporter reporter) {
 		case <-reporter.V2().Done():
 			return
 		case <-t.C:
-			msw.fetch(reporter)
+			msw.fetch(ctx, reporter)
 		}
 	}
 }
@@ -224,7 +227,7 @@ func (msw *metricSetWrapper) startPeriodicFetching(reporter reporter) {
 // fetch invokes the appropriate Fetch method for the MetricSet and publishes
 // the result using the publisher client. This method will recover from panics
 // and log a stack track if one occurs.
-func (msw *metricSetWrapper) fetch(reporter reporter) {
+func (msw *metricSetWrapper) fetch(ctx context.Context, reporter reporter) {
 	switch fetcher := msw.MetricSet.(type) {
 	case mb.EventFetcher:
 		msw.singleEventFetch(fetcher, reporter)
@@ -243,6 +246,13 @@ func (msw *metricSetWrapper) fetch(reporter reporter) {
 			reporter.V2().Error(err)
 			logp.Info("Error fetching data for metricset %s.%s: %s", msw.module.Name(), msw.Name(), err)
 		}
+	case mb.ReportingMetricSetV2WithContext:
+		reporter.StartFetchTimer()
+		err := fetcher.Fetch(ctx, reporter.V2())
+		if err != nil {
+			reporter.V2().Error(err)
+			logp.Info("Error fetching data for metricset %s.%s: %s", msw.module.Name(), msw.Name(), err)
+		}
 	default:
 		panic(fmt.Sprintf("unexpected fetcher type for %v", msw))
 	}
@@ -313,6 +323,23 @@ func (r *eventReporter) V1() mb.PushReporter {
 }
 func (r *eventReporter) V2() mb.PushReporterV2 { return reporterV2{r} }
 
+// channelContext implements context.Context by wrapping a channel
+type channelContext struct {
+	done <-chan struct{}
+}
+
+func (r *channelContext) Deadline() (time.Time, bool) { return time.Time{}, false }
+func (r *channelContext) Done() <-chan struct{}       { return r.done }
+func (r *channelContext) Err() error {
+	select {
+	case <-r.done:
+		return context.Canceled
+	default:
+		return nil
+	}
+}
+func (r *channelContext) Value(key interface{}) interface{} { return nil }
+
 // reporterV1 wraps V2 to provide a v1 interface.
 type reporterV1 struct {
 	v2     mb.PushReporterV2
diff --git a/metricbeat/mb/testing/data_generator.go b/metricbeat/mb/testing/data_generator.go
index 1bdb6c69e5d4..1f75f44e9d7b 100644
--- a/metricbeat/mb/testing/data_generator.go
+++ b/metricbeat/mb/testing/data_generator.go
@@ -99,6 +99,12 @@ func WriteEventsReporterV2Error(f mb.ReportingMetricSetV2Error, t testing.TB, pa
 	return WriteEventsReporterV2ErrorCond(f, t, path, nil)
 }
 
+// WriteEventsReporterV2WithContext fetches events and writes the first event to a ./_meta/data.json
+// file.
+func WriteEventsReporterV2WithContext(f mb.ReportingMetricSetV2WithContext, t testing.TB, path string) error {
+	return WriteEventsReporterV2WithContextCond(f, t, path, nil)
+}
+
 // WriteEventsReporterV2Cond fetches events and writes the first event that matches
 // the condition to a file.
 func WriteEventsReporterV2Cond(f mb.ReportingMetricSetV2, t testing.TB, path string, cond func(common.MapStr) bool) error {
@@ -129,6 +135,21 @@ func WriteEventsReporterV2ErrorCond(f mb.ReportingMetricSetV2Error, t testing.TB
 	return writeEvent(events, f, t, path, cond)
 }
 
+// WriteEventsReporterV2WithContextCond fetches events and writes the first event that matches
+// the condition to a file.
+func WriteEventsReporterV2WithContextCond(f mb.ReportingMetricSetV2WithContext, t testing.TB, path string, cond func(common.MapStr) bool) error {
+	if !*dataFlag {
+		t.Skip("skip data generation tests")
+	}
+
+	events, errs := ReportingFetchV2WithContext(f)
+	if len(errs) > 0 {
+		return errs[0]
+	}
+
+	return writeEvent(events, f, t, path, cond)
+}
+
 func writeEvent(events []mb.Event, f mb.MetricSet, t testing.TB, path string, cond func(common.MapStr) bool) error {
 	if len(events) == 0 {
 		return fmt.Errorf("no events were generated")
diff --git a/metricbeat/mb/testing/modules.go b/metricbeat/mb/testing/modules.go
index ec5ebc0f5322..12ff0e9a1add 100644
--- a/metricbeat/mb/testing/modules.go
+++ b/metricbeat/mb/testing/modules.go
@@ -54,6 +54,7 @@ that Metricbeat does it and with the same validations.
 package testing
 
 import (
+	"context"
 	"sync"
 	"testing"
 	"time"
@@ -181,6 +182,19 @@ func NewReportingMetricSetV2Error(t testing.TB, config interface{}) mb.Reporting
 	return reportingMetricSetV2Error
 }
 
+// NewReportingMetricSetV2WithContext returns a new ReportingMetricSetV2WithContext instance. Then
+// you can use ReportingFetchV2 to perform a Fetch operation with the MetricSet.
+func NewReportingMetricSetV2WithContext(t testing.TB, config interface{}) mb.ReportingMetricSetV2WithContext {
+	metricSet := NewMetricSet(t, config)
+
+	reportingMetricSet, ok := metricSet.(mb.ReportingMetricSetV2WithContext)
+	if !ok {
+		t.Fatal("MetricSet does not implement ReportingMetricSetV2WithContext")
+	}
+
+	return reportingMetricSet
+}
+
 // CapturingReporterV2 is a reporter used for testing which stores all events and errors
 type CapturingReporterV2 struct {
 	events []mb.Event
@@ -228,6 +242,17 @@ func ReportingFetchV2Error(metricSet mb.ReportingMetricSetV2Error) ([]mb.Event,
 	return r.events, r.errs
 }
 
+// ReportingFetchV2WithContext runs the given reporting metricset and returns all of the
+// events and errors that occur during that period.
+func ReportingFetchV2WithContext(metricSet mb.ReportingMetricSetV2WithContext) ([]mb.Event, []error) {
+	r := &CapturingReporterV2{}
+	err := metricSet.Fetch(context.Background(), r)
+	if err != nil {
+		r.errs = append(r.errs, err)
+	}
+	return r.events, r.errs
+}
+
 // NewPushMetricSet instantiates a new PushMetricSet using the given
 // configuration. The ModuleFactory and MetricSetFactory are obtained from the
 // global Registry.
@@ -301,7 +326,21 @@ func NewPushMetricSetV2(t testing.TB, config interface{}) mb.PushMetricSetV2 {
 
 	pushMetricSet, ok := metricSet.(mb.PushMetricSetV2)
 	if !ok {
-		t.Fatal("MetricSet does not implement PushMetricSet")
+		t.Fatal("MetricSet does not implement PushMetricSetV2")
+	}
+
+	return pushMetricSet
+}
+
+// NewPushMetricSetV2WithContext instantiates a new PushMetricSetV2WithContext
+// using the given configuration. The ModuleFactory and MetricSetFactory are
+// obtained from the global Registry.
+func NewPushMetricSetV2WithContext(t testing.TB, config interface{}) mb.PushMetricSetV2WithContext {
+	metricSet := NewMetricSet(t, config)
+
+	pushMetricSet, ok := metricSet.(mb.PushMetricSetV2WithContext)
+	if !ok {
+		t.Fatal("MetricSet does not implement PushMetricSetV2WithContext")
 	}
 
 	return pushMetricSet
@@ -310,15 +349,19 @@ func NewPushMetricSetV2(t testing.TB, config interface{}) mb.PushMetricSetV2 {
 // capturingPushReporterV2 stores all the events and errors from a metricset's
 // Run method.
 type capturingPushReporterV2 struct {
-	doneC   chan struct{}
+	context.Context
 	eventsC chan mb.Event
 }
 
+func newCapturingPushReporterV2(ctx context.Context) *capturingPushReporterV2 {
+	return &capturingPushReporterV2{Context: ctx, eventsC: make(chan mb.Event)}
+}
+
 // report writes an event to the output channel and returns true. If the output
 // is closed it returns false.
 func (r *capturingPushReporterV2) report(event mb.Event) bool {
 	select {
-	case <-r.doneC:
+	case <-r.Done():
 		// Publisher is stopped.
 		return false
 	case r.eventsC <- event:
@@ -336,54 +379,42 @@ func (r *capturingPushReporterV2) Error(err error) bool {
 	return r.report(mb.Event{Error: err})
 }
 
-// Done returns the Done channel for this reporter.
-func (r *capturingPushReporterV2) Done() <-chan struct{} {
-	return r.doneC
+func (r *capturingPushReporterV2) capture(waitEvents int) []mb.Event {
+	var events []mb.Event
+	for {
+		select {
+		case <-r.Done():
+			// Timeout
+			return events
+		case e := <-r.eventsC:
+			events = append(events, e)
+			if waitEvents > 0 && len(events) >= waitEvents {
+				return events
+			}
+		}
+	}
 }
 
 // RunPushMetricSetV2 run the given push metricset for the specific amount of
 // time and returns all of the events and errors that occur during that period.
 func RunPushMetricSetV2(timeout time.Duration, waitEvents int, metricSet mb.PushMetricSetV2) []mb.Event {
-	var (
-		r      = &capturingPushReporterV2{doneC: make(chan struct{}), eventsC: make(chan mb.Event)}
-		wg     sync.WaitGroup
-		events []mb.Event
-	)
-	wg.Add(2)
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
 
-	// Producer
-	go func() {
-		defer wg.Done()
-		defer close(r.eventsC)
-		if closer, ok := metricSet.(mb.Closer); ok {
-			defer closer.Close()
-		}
-		metricSet.Run(r)
-	}()
+	r := newCapturingPushReporterV2(ctx)
 
-	// Consumer
-	go func() {
-		defer wg.Done()
-		defer close(r.doneC)
-
-		timer := time.NewTimer(timeout)
-		defer timer.Stop()
-		for {
-			select {
-			case <-timer.C:
-				return
-			case e, ok := <-r.eventsC:
-				if !ok {
-					return
-				}
-				events = append(events, e)
-				if waitEvents > 0 && waitEvents <= len(events) {
-					return
-				}
-			}
-		}
-	}()
+	go metricSet.Run(r)
+	return r.capture(waitEvents)
+}
 
-	wg.Wait()
-	return events
+// RunPushMetricSetV2WithContext run the given push metricset for the specific amount of
+// time and returns all of the events that occur during that period.
+func RunPushMetricSetV2WithContext(timeout time.Duration, waitEvents int, metricSet mb.PushMetricSetV2WithContext) []mb.Event {
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	r := newCapturingPushReporterV2(ctx)
+
+	go metricSet.Run(ctx, r)
+	return r.capture(waitEvents)
 }
diff --git a/metricbeat/module/docker/container/container.go b/metricbeat/module/docker/container/container.go
index 8f84dc3eda2d..940d6e333f1e 100644
--- a/metricbeat/module/docker/container/container.go
+++ b/metricbeat/module/docker/container/container.go
@@ -63,9 +63,9 @@ func New(base mb.BaseMetricSet) (mb.MetricSet, error) {
 
 // Fetch returns a list of all containers as events.
 // This is based on https://docs.docker.com/engine/reference/api/docker_remote_api_v1.24/#/list-containers.
-func (m *MetricSet) Fetch(r mb.ReporterV2) error {
+func (m *MetricSet) Fetch(ctx context.Context, r mb.ReporterV2) error {
 	// Fetch a list of all containers.
-	containers, err := m.dockerClient.ContainerList(context.Background(), types.ContainerListOptions{})
+	containers, err := m.dockerClient.ContainerList(ctx, types.ContainerListOptions{})
 	if err != nil {
 		return errors.Wrap(err, "failed to get docker containers list")
 	}
diff --git a/metricbeat/module/docker/container/container_integration_test.go b/metricbeat/module/docker/container/container_integration_test.go
index 2b9c88e1b5ad..8802c30f6c32 100644
--- a/metricbeat/module/docker/container/container_integration_test.go
+++ b/metricbeat/module/docker/container/container_integration_test.go
@@ -26,8 +26,8 @@ import (
 )
 
 func TestData(t *testing.T) {
-	f := mbtest.NewReportingMetricSetV2Error(t, getConfig())
-	if err := mbtest.WriteEventsReporterV2Error(f, t, ""); err != nil {
+	f := mbtest.NewReportingMetricSetV2WithContext(t, getConfig())
+	if err := mbtest.WriteEventsReporterV2WithContext(f, t, ""); err != nil {
 		t.Fatal("write", err)
 	}
 }
diff --git a/metricbeat/module/docker/event/event.go b/metricbeat/module/docker/event/event.go
index cae7ceaa561e..6426ffeb961a 100644
--- a/metricbeat/module/docker/event/event.go
+++ b/metricbeat/module/docker/event/event.go
@@ -76,8 +76,7 @@ func New(base mb.BaseMetricSet) (mb.MetricSet, error) {
 }
 
 // Run listens for docker events and reports them
-func (m *MetricSet) Run(reporter mb.PushReporterV2) {
-	ctx, cancel := context.WithCancel(context.Background())
+func (m *MetricSet) Run(ctx context.Context, reporter mb.ReporterV2) {
 	options := types.EventsOptions{
 		Since: fmt.Sprintf("%d", time.Now().Unix()),
 	}
@@ -100,16 +99,15 @@ func (m *MetricSet) Run(reporter mb.PushReporterV2) {
 				time.Sleep(1 * time.Second)
 				break WATCH
 
-			case <-reporter.Done():
+			case <-ctx.Done():
 				m.logger.Debug("docker", "event watcher stopped")
-				cancel()
 				return
 			}
 		}
 	}
 }
 
-func (m *MetricSet) reportEvent(reporter mb.PushReporterV2, event events.Message) {
+func (m *MetricSet) reportEvent(reporter mb.ReporterV2, event events.Message) {
 	time := time.Unix(event.Time, 0)
 
 	attributes := make(map[string]string, len(event.Actor.Attributes))
diff --git a/metricbeat/module/docker/event/event_integration_test.go b/metricbeat/module/docker/event/event_integration_test.go
index 361cb0e5c95a..3cce1486672c 100644
--- a/metricbeat/module/docker/event/event_integration_test.go
+++ b/metricbeat/module/docker/event/event_integration_test.go
@@ -20,6 +20,7 @@
 package event
 
 import (
+	"context"
 	"io"
 	"os"
 	"testing"
@@ -28,7 +29,6 @@ import (
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/client"
-	"golang.org/x/net/context"
 
 	"github.com/elastic/beats/auditbeat/core"
 	"github.com/elastic/beats/metricbeat/mb"
@@ -36,11 +36,11 @@ import (
 )
 
 func TestData(t *testing.T) {
-	ms := mbtest.NewPushMetricSetV2(t, getConfig())
+	ms := mbtest.NewPushMetricSetV2WithContext(t, getConfig())
 	var events []mb.Event
 	done := make(chan interface{})
 	go func() {
-		events = mbtest.RunPushMetricSetV2(10*time.Second, 1, ms)
+		events = mbtest.RunPushMetricSetV2WithContext(10*time.Second, 1, ms)
 		close(done)
 	}()