From 7af03b7ac08bd33b98cffee1ab7d39c11bfbb851 Mon Sep 17 00:00:00 2001 From: Rahul Gupta Date: Wed, 13 Oct 2021 12:24:04 +0530 Subject: [PATCH] feat: rest infrastructure --- cmd/collectors/rest/rest.go | 312 ++++++++++++++++++++++++++++++ cmd/collectors/rest/templating.go | 119 ++++++++++++ cmd/poller/poller.go | 1 + cmd/tools/rest/client.go | 99 ++++++++-- cmd/tools/rest/rest.go | 51 ++--- cmd/tools/rest/swag.go | 7 - conf/rest/9.7.0/disk.yaml | 37 ++++ conf/rest/default.yaml | 8 + pkg/matrix/metric.go | 5 + pkg/matrix/metric_uint8.go | 10 + pkg/util/util.go | 32 +++ 11 files changed, 626 insertions(+), 55 deletions(-) create mode 100644 cmd/collectors/rest/rest.go create mode 100644 cmd/collectors/rest/templating.go create mode 100644 conf/rest/9.7.0/disk.yaml create mode 100644 conf/rest/default.yaml diff --git a/cmd/collectors/rest/rest.go b/cmd/collectors/rest/rest.go new file mode 100644 index 000000000..c3f21aaae --- /dev/null +++ b/cmd/collectors/rest/rest.go @@ -0,0 +1,312 @@ +package rest + +import ( + "encoding/json" + "fmt" + "github.com/tidwall/gjson" + "goharvest2/cmd/poller/collector" + "goharvest2/cmd/poller/plugin" + "goharvest2/cmd/tools/rest" + "goharvest2/pkg/conf" + "goharvest2/pkg/errors" + "goharvest2/pkg/matrix" + "goharvest2/pkg/tree/node" + "goharvest2/pkg/util" + "os" + "strconv" + "strings" + "time" +) + +type Rest struct { + *collector.AbstractCollector + client *rest.Client + apiPath string + instanceKeys []string + instanceLabels map[string]string + counters map[string]string + fields []string + misses []string + returnTimeOut string +} + +func init() { + plugin.RegisterModule(Rest{}) +} + +func (Rest) HarvestModule() plugin.ModuleInfo { + return plugin.ModuleInfo{ + ID: "harvest.collector.rest", + New: func() plugin.Module { return new(Rest) }, + } +} + +func (r *Rest) Init(a *collector.AbstractCollector) error { + + var err error + + r.AbstractCollector = a + if err = collector.Init(r); err != nil { + return err + } + + if r.client, err = r.getClient(a, a.Params); err != nil { + return err + } + + if err = r.client.Init(5); err != nil { + return err + } + + r.Logger.Info().Msgf("connected to %s: %s", r.client.ClusterName(), r.client.Info()) + + r.Matrix.SetGlobalLabel("cluster", r.client.ClusterName()) + + if err = r.initCache(r.getTemplateFn(), r.client.Version()); err != nil { + return err + } + r.Logger.Info().Msgf("initialized cache with %d metrics", len(r.Matrix.GetMetrics())) + return nil +} + +func (r *Rest) getClient(a *collector.AbstractCollector, config *node.Node) (*rest.Client, error) { + var ( + poller *conf.Poller + addr string + err error + client *rest.Client + ) + + opt := a.GetOptions() + if poller, err = conf.GetPoller2(opt.Config, opt.Poller); err != nil { + r.Logger.Error().Stack().Err(err).Str("poller", opt.Poller).Msgf("") + return nil, err + } + if addr = util.Value(poller.Addr, ""); addr == "" { + r.Logger.Error().Stack().Str("poller", opt.Poller).Str("addr", addr).Msgf("Invalid address") + return nil, errors.New(errors.MISSING_PARAM, "addr") + } + + timeout := rest.DefaultTimeout + + if t, err := strconv.Atoi(config.GetChildContentS("client_timeout")); err == nil { + timeout = time.Duration(t) * time.Second + } else { + // default timeout + timeout = rest.DefaultTimeout + } + + if client, err = rest.New(poller, timeout); err != nil { + fmt.Printf("error creating new client %+v\n", err) + os.Exit(1) + } + + return client, err +} + +func (r *Rest) getTemplateFn() string { + var fn string + objects := r.Params.GetChildS("objects") + if objects != nil { + fn = objects.GetChildContentS(r.Object) + } + return fn +} + +// Returns a slice of keys in dot notation from json +func getFieldName(source string, parent string) []string { + res := make([]string, 0) + var arr map[string]gjson.Result + if gjson.Parse(source).IsObject() { + arr = gjson.Parse(source).Map() + } else { + return []string{parent} + } + for key, val := range arr { + var temp []string + if parent == "" { + temp = getFieldName(val.Raw, key) + } else { + temp = getFieldName(val.Raw, parent+"."+key) + } + res = append(res, temp...) + } + return res +} + +func (r *Rest) PollData() (*matrix.Matrix, error) { + + var ( + content []byte + count uint64 + apiD, parseD time.Duration + startTime time.Time + err error + records []interface{} + ) + + r.Logger.Info().Msgf("starting data poll") + r.Matrix.Reset() + + startTime = time.Now() + // Check if fields are set for rest call + if len(r.fields) > 0 { + href := rest.BuildHref(r.apiPath, strings.Join(r.fields[:], ","), nil, "", "", "", r.returnTimeOut) + r.Logger.Info().Msgf("rest end point [%s]", href) + err = rest.FetchData(r.client, href, &records) + if err != nil { + r.Logger.Error().Stack().Err(err).Msgf("") + return nil, err + } + } else { + href := rest.BuildHref(r.apiPath, "*", nil, "", "", "", r.returnTimeOut) + r.Logger.Info().Msgf("rest end point [%s]", href) + err = rest.FetchData(r.client, href, &records) + if err != nil { + r.Logger.Error().Stack().Err(err).Msgf("") + return nil, err + } + all := rest.Pagination{ + Records: records, + NumRecords: len(records), + } + c, err := json.Marshal(all) + if err != nil { + r.Logger.Error().Stack().Err(err).Msgf("") + return nil, err + } else { + results := gjson.GetBytes(c, "records") + if len(results.String()) > 0 { + // fetch first record from json + firstValue := results.Get("1") + res := getFieldName(firstValue.String(), "") + var searchKeys []string + for k := range r.counters { + searchKeys = append(searchKeys, k) + } + // find keys from rest json response which matches counters defined in templates + matches, misses := util.Intersection(res, searchKeys) + for _, param := range matches { + r.fields = append(r.fields, param.(string)) + } + for _, param := range misses { + r.misses = append(r.misses, param.(string)) + } + } + } + } + + if len(r.misses) > 0 { + r.Logger.Warn(). + Str("Missing Counters", strings.Join(r.misses[:], ",")). + Str("ApiPath", r.apiPath). + Msg("Mis configured counters") + } + + all := rest.Pagination{ + Records: records, + NumRecords: len(records), + } + apiD = time.Since(startTime) + + content, err = json.Marshal(all) + if err != nil { + r.Logger.Error().Stack().Err(err).Msgf("") + } + + startTime = time.Now() + if !gjson.ValidBytes(content) { + return nil, fmt.Errorf("json is not valid for: %s", r.apiPath) + } + parseD = time.Since(startTime) + + results := gjson.GetManyBytes(content, "num_records", "records") + numRecords := results[0] + if numRecords.Int() == 0 { + return nil, errors.New(errors.ERR_NO_INSTANCE, "no "+r.Object+" instances on cluster") + } + + r.Logger.Debug().Msgf("extracted %d [%s] instances", numRecords, r.Object) + + results[1].ForEach(func(key, instanceData gjson.Result) bool { + var ( + instanceKey string + instance *matrix.Instance + ) + + if !instanceData.IsObject() { + r.Logger.Warn().Str("type", instanceData.Type.String()).Msg("skip instance") + return true + } + + // extract instance key(s) + for _, k := range r.instanceKeys { + value := instanceData.Get(k) + if value.Exists() { + instanceKey += value.String() + } else { + r.Logger.Warn().Str("key", k).Msg("skip instance, missing key") + break + } + } + + if instanceKey == "" { + return true + } + + if instance = r.Matrix.GetInstance(instanceKey); instance == nil { + if instance, err = r.Matrix.NewInstance(instanceKey); err != nil { + r.Logger.Error().Msgf("NewInstance [key=%s]: %v", instanceKey, err) + return true + } + } + + for label, display := range r.instanceLabels { + value := instanceData.Get(label) + if value.Exists() { + instance.SetLabel(display, value.String()) + count++ + } + } + + for key, metric := range r.Matrix.GetMetrics() { + + if metric.GetProperty() == "etl.bool" { + b := instanceData.Get(key) + if b.Exists() { + if err = metric.SetValueBool(instance, b.Bool()); err != nil { + r.Logger.Error().Err(err).Str("key", key).Msg("SetValueBool metric") + } + count++ + } + } else if metric.GetProperty() == "etl.float" { + f := instanceData.Get(key) + if f.Exists() { + if err = metric.SetValueFloat64(instance, f.Float()); err != nil { + r.Logger.Error().Err(err).Str("key", key).Msg("SetValueFloat64 metric") + } + count++ + } + } + } + return true + }) + + r.Logger.Info(). + Uint64("dataPoints", count). + Str("apiTime", apiD.String()). + Str("parseTime", parseD.String()). + Msg("Collected") + + _ = r.Metadata.LazySetValueInt64("api_time", "data", apiD.Microseconds()) + _ = r.Metadata.LazySetValueInt64("parse_time", "data", parseD.Microseconds()) + _ = r.Metadata.LazySetValueUint64("count", "data", count) + r.AddCollectCount(count) + + return r.Matrix, nil +} + +// Interface guards +var ( + _ collector.Collector = (*Rest)(nil) +) diff --git a/cmd/collectors/rest/templating.go b/cmd/collectors/rest/templating.go new file mode 100644 index 000000000..fea32e339 --- /dev/null +++ b/cmd/collectors/rest/templating.go @@ -0,0 +1,119 @@ +package rest + +import ( + "goharvest2/pkg/errors" + "goharvest2/pkg/matrix" + "goharvest2/pkg/tree/node" + "strings" +) + +func (r *Rest) initCache(templateFn string, version [3]int) error { + + var ( + template, counters *node.Node + display, name, kind string + metr matrix.Metric + err error + ) + + // import template + if template, err = r.ImportSubTemplate("", templateFn, version); err != nil { + return err + } + + r.Logger.Info().Msg("imported subtemplate") + r.Params.Union(template) + + if x := r.Params.GetChildContentS("object"); x != "" { + r.Object = x + } else { + r.Object = strings.ToLower(r.Object) + } + r.Matrix.Object = r.Object + + if e := r.Params.GetChildS("export_options"); e != nil { + r.Matrix.SetExportOptions(e) + } + + if r.apiPath = r.Params.GetChildContentS("query"); r.apiPath == "" { + return errors.New(errors.MISSING_PARAM, "query") + } + + // create metric cache + if counters = r.Params.GetChildS("counters"); counters == nil { + return errors.New(errors.MISSING_PARAM, "counters") + } + + // default value for ONTAP is 15 sec + if returnTimeout := r.Params.GetChildContentS("return_timeout"); returnTimeout != "" { + r.returnTimeOut = returnTimeout + } + + r.instanceKeys = make([]string, 0) + r.instanceLabels = make(map[string]string) + r.counters = make(map[string]string) + + for _, c := range counters.GetAllChildContentS() { + name, display, kind = parseMetric(c) + r.Logger.Debug().Msgf("extracted [%s] (%s) (%s)", kind, name, display) + r.counters[name] = display + switch kind { + case "key": + r.instanceLabels[name] = display + r.instanceKeys = append(r.instanceKeys, name) + case "label": + r.instanceLabels[name] = display + case "bool": + if metr, err = r.Matrix.NewMetricUint8(name); err != nil { + r.Logger.Error().Msgf("NewMetricUint8 [%s]: %v", name, err) + return err + } + metr.SetName(display) + metr.SetProperty("etl.bool") // to distinct from internally generated metrics, e.g. from plugins + case "float": + if metr, err = r.Matrix.NewMetricFloat64(name); err != nil { + r.Logger.Error().Msgf("NewMetricFloat64 [%s]: %v", name, err) + return err + } + metr.SetName(display) + metr.SetProperty("etl.float") + } + } + + r.Logger.Info().Msgf("extracted instance keys: %v", r.instanceKeys) + r.Logger.Info().Msgf("initialized metric cache with %d metrics and %d labels", len(r.Matrix.GetMetrics()), len(r.instanceLabels)) + + if len(r.Matrix.GetMetrics()) == 0 && r.Params.GetChildContentS("collect_only_labels") != "true" { + return errors.New(errors.ERR_NO_METRIC, "failed to parse numeric metrics") + } + return nil + +} + +func parseMetric(rawName string) (string, string, string) { + var ( + name, display string + values []string + ) + if values = strings.SplitN(rawName, "=>", 2); len(values) == 2 { + name = strings.TrimSpace(values[0]) + display = strings.TrimSpace(values[1]) + } else { + name = rawName + display = strings.ReplaceAll(rawName, ".", "_") + } + + if strings.HasPrefix(name, "^^") { + return strings.TrimPrefix(name, "^^"), strings.TrimPrefix(display, "^^"), "key" + } + + if strings.HasPrefix(name, "^") { + return strings.TrimPrefix(name, "^"), strings.TrimPrefix(display, "^"), "label" + } + + if strings.HasPrefix(name, "?") { + return strings.TrimPrefix(name, "?"), strings.TrimPrefix(display, "?"), "bool" + } + + return name, display, "float" +} diff --git a/cmd/poller/poller.go b/cmd/poller/poller.go index 317eed6dd..18a964afd 100644 --- a/cmd/poller/poller.go +++ b/cmd/poller/poller.go @@ -26,6 +26,7 @@ package main import ( "fmt" "github.com/spf13/cobra" + _ "goharvest2/cmd/collectors/rest" _ "goharvest2/cmd/collectors/simple" _ "goharvest2/cmd/collectors/unix" _ "goharvest2/cmd/collectors/zapi/collector" diff --git a/cmd/tools/rest/client.go b/cmd/tools/rest/client.go index 6811cef47..9f9a062f1 100644 --- a/cmd/tools/rest/client.go +++ b/cmd/tools/rest/client.go @@ -5,10 +5,13 @@ package rest import ( "bytes" "crypto/tls" + "encoding/json" "fmt" + "github.com/tidwall/gjson" "goharvest2/pkg/conf" "goharvest2/pkg/errors" "goharvest2/pkg/logging" + "goharvest2/pkg/util" "io" "io/ioutil" "net/http" @@ -28,17 +31,24 @@ type Client struct { buffer *bytes.Buffer Logger *logging.Logger baseURL string + cluster Cluster password string username string } -func New(poller *conf.Poller) (*Client, error) { +type Cluster struct { + name string + info string + uuid string + version [3]int +} + +func New(poller *conf.Poller, timeout time.Duration) (*Client, error) { var ( client Client httpclient *http.Client transport *http.Transport cert tls.Certificate - timeout time.Duration addr *string url string useInsecureTLS bool @@ -68,8 +78,8 @@ func New(poller *conf.Poller) (*Client, error) { // set authentication method if poller.AuthStyle != nil && *poller.AuthStyle == "certificate_auth" { - certPath := value(poller.SslCert, "") - keyPath := value(poller.SslKey, "") + certPath := util.Value(poller.SslCert, "") + keyPath := util.Value(poller.SslKey, "") if certPath == "" { return nil, errors.New(errors.MISSING_PARAM, "ssl_cert") } else if keyPath == "" { @@ -85,7 +95,7 @@ func New(poller *conf.Poller) (*Client, error) { InsecureSkipVerify: useInsecureTLS}, } } else { - username := value(poller.Username, "") + username := util.Value(poller.Username, "") password := poller.Password client.username = username client.password = password @@ -101,13 +111,6 @@ func New(poller *conf.Poller) (*Client, error) { } } - timeout = DefaultTimeout - if poller.ClientTimeout != nil { - timeout, err = time.ParseDuration(*poller.ClientTimeout) - if err != nil { - client.Logger.Error().Msgf("err paring client timeout of=[%s] err=%+v\n", timeout, err) - } - } client.Logger.Debug().Msgf("using timeout [%d]", timeout) httpclient = &http.Client{Transport: transport, Timeout: timeout} @@ -193,7 +196,7 @@ func downloadSwagger(poller *conf.Poller, path string, url string) (int64, error return 0, err } - if restClient, err = New(poller); err != nil { + if restClient, err = New(poller, DefaultTimeout); err != nil { return 0, fmt.Errorf("error creating new client %w\n", err) } @@ -213,3 +216,73 @@ func downloadSwagger(poller *conf.Poller, path string, url string) (int64, error } return n, nil } + +func (c *Client) Init(retries int) error { + + var ( + err error + content []byte + data map[string]interface{} + i int + ) + + for i = 0; i < retries; i++ { + + if content, err = c.GetRest(BuildHref("cluster", "*", nil, "", "", "", "")); err != nil { + continue + } + if err = json.Unmarshal(content, &data); err != nil { + return err + } + + results := gjson.GetManyBytes(content, "name", "uuid", "version.full", "version.generation", "version.major", "version.minor") + c.cluster.name = results[0].String() + c.cluster.uuid = results[1].String() + c.cluster.info = results[2].String() + c.cluster.version[0] = int(results[3].Int()) + c.cluster.version[1] = int(results[4].Int()) + c.cluster.version[2] = int(results[5].Int()) + return nil + } + return err +} + +func BuildHref(apiPath string, fields string, field []string, queryFields string, queryValue string, maxRecords string, returnTimeout string) string { + href := strings.Builder{} + href.WriteString("api/") + href.WriteString(apiPath) + href.WriteString("?return_records=true") + addArg(&href, "&fields=", fields) + for _, field := range field { + addArg(&href, "&", field) + } + addArg(&href, "&query_fields=", queryFields) + addArg(&href, "&query=", queryValue) + addArg(&href, "&max_records=", maxRecords) + addArg(&href, "&return_timeout=", returnTimeout) + return href.String() +} + +func addArg(href *strings.Builder, field string, value string) { + if value == "" { + return + } + href.WriteString(field) + href.WriteString(value) +} + +func (c *Client) ClusterName() string { + return c.cluster.name +} + +func (c *Client) ClusterUUID() string { + return c.cluster.uuid +} + +func (c *Client) Info() string { + return c.cluster.info +} + +func (c *Client) Version() [3]int { + return c.cluster.version +} diff --git a/cmd/tools/rest/rest.go b/cmd/tools/rest/rest.go index ddda2d44b..20cd8a1a6 100644 --- a/cmd/tools/rest/rest.go +++ b/cmd/tools/rest/rest.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/spf13/cobra" "goharvest2/pkg/conf" + "goharvest2/pkg/util" "io" "os" "path/filepath" @@ -165,7 +166,7 @@ func doData() { return } - if client, err = New(poller); err != nil { + if client, err = New(poller, DefaultTimeout); err != nil { fmt.Printf("error creating new client %+v\n", err) os.Exit(1) } @@ -175,8 +176,14 @@ func doData() { args.Api = args.Api[1:] } var records []interface{} - fetchData(client, buildHref(), &records) + href := BuildHref(args.Api, args.Fields, args.Field, args.QueryField, args.QueryValue, args.MaxRecords, "") + stderr("fetching href=[%s]\n", href) + err = FetchData(client, href, &records) + if err != nil { + stderr("error %+v\n", err) + return + } all := Pagination{ Records: records, NumRecords: len(records), @@ -199,50 +206,23 @@ func getPollerAndAddr() (*conf.Poller, string, error) { fmt.Printf("Poller named [%s] does not exist\n", args.Poller) return nil, "", err } - if addr = value(poller.Addr, ""); addr == "" { + if addr = util.Value(poller.Addr, ""); addr == "" { fmt.Printf("Poller named [%s] does not have a valid addr=[%s]\n", args.Poller, addr) return nil, "", err } return poller, addr, nil } -func buildHref() string { - href := strings.Builder{} - href.WriteString("api/") - href.WriteString(args.Api) - href.WriteString("?return_records=true") - addArg(&href, "&fields=", args.Fields) - for _, field := range args.Field { - addArg(&href, "&", field) - } - addArg(&href, "&query_fields=", args.QueryField) - addArg(&href, "&query=", args.QueryValue) - addArg(&href, "&max_records=", args.MaxRecords) - - return href.String() -} - -func addArg(href *strings.Builder, field string, value string) { - if value == "" { - return - } - href.WriteString(field) - href.WriteString(value) -} - -func fetchData(client *Client, href string, records *[]interface{}) { - stderr("fetching href=[%s]\n", href) +func FetchData(client *Client, href string, records *[]interface{}) error { getRest, err := client.GetRest(href) if err != nil { - stderr("error making request api=%s err=%+v\n", href, err) - return + return fmt.Errorf("error making request api=%s err=%+v\n", href, err) } else { // extract returned records since paginated records need to be merged into a single list var page Pagination err := json.Unmarshal(getRest, &page) if err != nil { - stderr("error unmarshalling json %+v\n", err) - return + return fmt.Errorf("error unmarshalling json %+v\n", err) } *records = append(*records, page.Records...) @@ -253,12 +233,13 @@ func fetchData(client *Client, href string, records *[]interface{}) { if nextLink != "" { if nextLink == href { // nextLink is same as previous link, no progress is being made, exit - return + return nil } - fetchData(client, nextLink, records) + FetchData(client, nextLink, records) } } } + return nil } func stderr(format string, a ...interface{}) { diff --git a/cmd/tools/rest/swag.go b/cmd/tools/rest/swag.go index 868350918..46f304899 100644 --- a/cmd/tools/rest/swag.go +++ b/cmd/tools/rest/swag.go @@ -435,10 +435,3 @@ func showApis(ontapSwag ontap) { } table.Render() } - -func value(ptr *string, nilValue string) string { - if ptr == nil { - return nilValue - } - return *ptr -} diff --git a/conf/rest/9.7.0/disk.yaml b/conf/rest/9.7.0/disk.yaml new file mode 100644 index 000000000..22ca034cc --- /dev/null +++ b/conf/rest/9.7.0/disk.yaml @@ -0,0 +1,37 @@ + +name: Disk +query: storage/disks +object: disk + +counters: + - ^name => disk + - ^^uid + - ^serial_number + - ^model + - ^vendor + - ^firmware_version + - usable_size + - rated_life_used_percent + - ^type + - ^container_type + - ^pool + - ^state + - ^node.uuid + - ^node.name => node + - ^home_node.uuid + - ^home_node.name + - ^shelf.uid => shelf + - ^bay => shelf_bay + - ?self_encrypting + - ?fips_certified + +export_options: + instance_keys: + - node + - disk + instance_labels: + - type + - model + - shelf + - shelf_bay + - serial_number diff --git a/conf/rest/default.yaml b/conf/rest/default.yaml new file mode 100644 index 000000000..b057f9a44 --- /dev/null +++ b/conf/rest/default.yaml @@ -0,0 +1,8 @@ + +collector: Rest + +schedule: + - data: 180s + +objects: + Disk: disk.yaml diff --git a/pkg/matrix/metric.go b/pkg/matrix/metric.go index 4015fd8e6..d0667b8ac 100644 --- a/pkg/matrix/metric.go +++ b/pkg/matrix/metric.go @@ -51,6 +51,7 @@ type Metric interface { SetValueFloat64(*Instance, float64) error SetValueString(*Instance, string) error SetValueBytes(*Instance, []byte) error + SetValueBool(*Instance, bool) error AddValueInt(*Instance, int) error AddValueInt32(*Instance, int32) error @@ -208,3 +209,7 @@ func (me *AbstractMetric) MultiplyByScalar(s int) error { func (me *AbstractMetric) AddValueString(i *Instance, s string) error { return errors.New(errors.ERR_IMPLEMENT, me.dtype) } + +func (me *AbstractMetric) SetValueBool(i *Instance, b bool) error { + return errors.New(errors.ERR_IMPLEMENT, me.dtype) +} diff --git a/pkg/matrix/metric_uint8.go b/pkg/matrix/metric_uint8.go index 93cd686f1..1842fbf29 100644 --- a/pkg/matrix/metric_uint8.go +++ b/pkg/matrix/metric_uint8.go @@ -75,6 +75,16 @@ func (me *MetricUint8) SetValueInt64(i *Instance, v int64) error { return errors.New(OVERFLOW_ERROR, fmt.Sprintf("convert int64 (%d) to uint32", v)) } +func (me *MetricUint8) SetValueBool(i *Instance, v bool) error { + me.record[i.index] = true + if v { + me.values[i.index] = 1 + } else { + me.values[i.index] = 0 + } + return nil +} + func (me *MetricUint8) SetValueUint8(i *Instance, v uint8) error { me.record[i.index] = true me.values[i.index] = v diff --git a/pkg/util/util.go b/pkg/util/util.go index 087eaf2ad..14621c986 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -11,6 +11,7 @@ import ( "io/ioutil" "os" "os/exec" + "reflect" "runtime" "strconv" "strings" @@ -180,6 +181,13 @@ func ContainsWholeWord(source string, search string) bool { return false } +func Value(ptr *string, nilValue string) string { + if ptr == nil { + return nilValue + } + return *ptr +} + func Contains(s []string, e string) bool { for _, a := range s { if a == e { @@ -188,3 +196,27 @@ func Contains(s []string, e string) bool { } return false } + +func Intersection(a interface{}, b interface{}) ([]interface{}, []interface{}) { + matches := make([]interface{}, 0) + misses := make([]interface{}, 0) + hash := make(map[interface{}]bool) + av := reflect.ValueOf(a) + bv := reflect.ValueOf(b) + + for i := 0; i < av.Len(); i++ { + el := av.Index(i).Interface() + hash[el] = true + } + + for i := 0; i < bv.Len(); i++ { + el := bv.Index(i).Interface() + if _, found := hash[el]; found { + matches = append(matches, el) + } else { + misses = append(misses, el) + } + } + + return matches, misses +}