From 70b43c30e47f4fde62385c048106ac71f82195a2 Mon Sep 17 00:00:00 2001 From: Gianmaria Del Monte <39946305+gmgigi96@users.noreply.github.com> Date: Mon, 3 Jul 2023 07:44:08 +0200 Subject: [PATCH] New reva config (#4015) --- changelog/unreleased/new-config.md | 11 + cmd/revad/main.go | 120 +++- cmd/revad/pkg/config/common.go | 226 +++++++ cmd/revad/pkg/config/config.go | 169 +++++ cmd/revad/pkg/config/config_test.go | 592 ++++++++++++++++++ cmd/revad/pkg/config/dump.go | 146 +++++ cmd/revad/pkg/config/dump_test.go | 247 ++++++++ cmd/revad/pkg/config/grpc.go | 77 +++ cmd/revad/pkg/config/http.go | 77 +++ cmd/revad/pkg/config/lookup.go | 257 ++++++++ cmd/revad/pkg/config/lookup_test.go | 198 ++++++ cmd/revad/pkg/config/parser.go | 100 +++ cmd/revad/pkg/config/parser_test.go | 97 +++ cmd/revad/pkg/config/serverless.go | 54 ++ cmd/revad/pkg/config/templates.go | 279 +++++++++ cmd/revad/pkg/config/templates_test.go | 135 ++++ cmd/revad/{internal => pkg}/grace/grace.go | 212 +++++-- cmd/revad/runtime/grpc.go | 150 +++++ cmd/revad/runtime/http.go | 88 +++ cmd/revad/runtime/option.go | 13 +- cmd/revad/runtime/runtime.go | 541 ++++++++-------- go.mod | 1 + go.sum | 2 + .../applicationauth/applicationauth.go | 2 +- .../grpc/services/appprovider/appprovider.go | 2 +- .../grpc/services/appregistry/appregistry.go | 2 +- .../services/appregistry/appregistry_test.go | 2 +- .../services/authprovider/authprovider.go | 2 +- .../services/authregistry/authregistry.go | 2 +- internal/grpc/services/datatx/datatx.go | 2 +- internal/grpc/services/gateway/gateway.go | 2 +- .../services/groupprovider/groupprovider.go | 2 +- .../grpc/services/helloworld/helloworld.go | 2 +- internal/grpc/services/ocmcore/ocmcore.go | 2 +- .../ocminvitemanager/ocminvitemanager.go | 2 +- .../ocmproviderauthorizer.go | 2 +- .../ocmshareprovider/ocmshareprovider.go | 2 +- .../grpc/services/permissions/permissions.go | 2 +- .../grpc/services/preferences/preferences.go | 2 +- .../publicshareprovider.go | 2 +- .../publicstorageprovider.go | 2 +- .../storageprovider/storageprovider.go | 2 +- .../storageregistry/storageregistry.go | 2 +- .../services/userprovider/userprovider.go | 2 +- .../usershareprovider/usershareprovider.go | 2 +- pkg/rgrpc/option.go | 62 ++ pkg/rgrpc/rgrpc.go | 235 ++----- pkg/rhttp/rhttp.go | 232 +++---- .../config.go => pkg/rserverless/option.go | 29 +- pkg/rserverless/rserverless.go | 66 +- pkg/sharedconf/sharedconf.go | 48 +- pkg/sharedconf/sharedconf_test.go | 62 +- pkg/utils/list/list.go | 22 + pkg/utils/maps/maps.go | 50 ++ pkg/utils/net/net.go | 54 ++ 55 files changed, 3853 insertions(+), 843 deletions(-) create mode 100644 changelog/unreleased/new-config.md create mode 100644 cmd/revad/pkg/config/common.go create mode 100644 cmd/revad/pkg/config/config.go create mode 100644 cmd/revad/pkg/config/config_test.go create mode 100644 cmd/revad/pkg/config/dump.go create mode 100644 cmd/revad/pkg/config/dump_test.go create mode 100644 cmd/revad/pkg/config/grpc.go create mode 100644 cmd/revad/pkg/config/http.go create mode 100644 cmd/revad/pkg/config/lookup.go create mode 100644 cmd/revad/pkg/config/lookup_test.go create mode 100644 cmd/revad/pkg/config/parser.go create mode 100644 cmd/revad/pkg/config/parser_test.go create mode 100644 cmd/revad/pkg/config/serverless.go create mode 100644 cmd/revad/pkg/config/templates.go create mode 100644 cmd/revad/pkg/config/templates_test.go rename cmd/revad/{internal => pkg}/grace/grace.go (68%) create mode 100644 cmd/revad/runtime/grpc.go create mode 100644 cmd/revad/runtime/http.go create mode 100644 pkg/rgrpc/option.go rename cmd/revad/internal/config/config.go => pkg/rserverless/option.go (61%) create mode 100644 pkg/utils/maps/maps.go create mode 100644 pkg/utils/net/net.go diff --git a/changelog/unreleased/new-config.md b/changelog/unreleased/new-config.md new file mode 100644 index 0000000000..9f9f1f0124 --- /dev/null +++ b/changelog/unreleased/new-config.md @@ -0,0 +1,11 @@ +Enhancement: New configuration + +Allow multiple driverts of the same service to be in the +same toml config. Add a `vars` section to contain common +parameters addressable using templates in the configuration +of the different drivers. Support templating to reference +values of other parameters in the configuration. +Assign random ports to services where the address is not +specified. + +https://github.com/cs3org/reva/pull/4015 diff --git a/cmd/revad/main.go b/cmd/revad/main.go index cfb1a5cdb4..991073ace7 100644 --- a/cmd/revad/main.go +++ b/cmd/revad/main.go @@ -21,6 +21,7 @@ package main import ( "flag" "fmt" + "io" "io/fs" "os" "path" @@ -28,11 +29,14 @@ import ( "sync" "syscall" - "github.com/cs3org/reva/cmd/revad/internal/config" - "github.com/cs3org/reva/cmd/revad/internal/grace" + "github.com/cs3org/reva/cmd/revad/pkg/config" + "github.com/cs3org/reva/cmd/revad/pkg/grace" "github.com/cs3org/reva/cmd/revad/runtime" + "github.com/cs3org/reva/pkg/logger" "github.com/cs3org/reva/pkg/sysinfo" "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/rs/zerolog" ) var ( @@ -41,13 +45,16 @@ var ( signalFlag = flag.String("s", "", "send signal to a master process: stop, quit, reload") configFlag = flag.String("c", "/etc/revad/revad.toml", "set configuration file") pidFlag = flag.String("p", "", "pid file. If empty defaults to a random file in the OS temporary directory") - logFlag = flag.String("log", "", "log messages with the given severity or above. One of: [trace, debug, info, warn, error, fatal, panic]") dirFlag = flag.String("dev-dir", "", "runs any toml file in the specified directory. Intended for development use only") // Compile time variables initialized with gcc flags. gitCommit, buildDate, version, goVersion string ) +var ( + revaProcs []*runtime.Reva +) + func main() { flag.Parse() @@ -126,7 +133,7 @@ func handleSignalFlag() { // kill process with signal if err := process.Signal(signal); err != nil { - fmt.Fprintf(os.Stderr, "error signaling process %d with signal %s\n", process.Pid, signal) + fmt.Fprintf(os.Stderr, "error signaling process %d with signal %s: %v\n", process.Pid, signal, err) os.Exit(1) } @@ -134,7 +141,7 @@ func handleSignalFlag() { } } -func getConfigs() ([]map[string]interface{}, error) { +func getConfigs() ([]*config.Config, error) { var confs []string // give priority to read from dev-dir if *dirFlag != "" { @@ -186,8 +193,8 @@ func getConfigsFromDir(dir string) (confs []string, err error) { return } -func readConfigs(files []string) ([]map[string]interface{}, error) { - confs := make([]map[string]interface{}, 0, len(files)) +func readConfigs(files []string) ([]*config.Config, error) { + confs := make([]*config.Config, 0, len(files)) for _, conf := range files { fd, err := os.Open(conf) if err != nil { @@ -195,49 +202,114 @@ func readConfigs(files []string) ([]map[string]interface{}, error) { } defer fd.Close() - v, err := config.Read(fd) + c, err := config.Load(fd) if err != nil { return nil, err } - confs = append(confs, v) + confs = append(confs, c) } return confs, nil } -func runConfigs(confs []map[string]interface{}) { +func runConfigs(confs []*config.Config) { + pidfile := getPidfile() if len(confs) == 1 { - runSingle(confs[0]) + runSingle(confs[0], pidfile) return } runMultiple(confs) } -func runSingle(conf map[string]interface{}) { - if *pidFlag == "" { - *pidFlag = getPidfile() - } - - runtime.Run(conf, *pidFlag, *logFlag) +func registerReva(r *runtime.Reva) { + revaProcs = append(revaProcs, r) } -func getPidfile() string { - uuid := uuid.New().String() - name := fmt.Sprintf("revad-%s.pid", uuid) +func runSingle(conf *config.Config, pidfile string) { + log := initLogger(conf.Log) + reva, err := runtime.New(conf, + runtime.WithPidFile(pidfile), + runtime.WithLogger(log), + ) + if err != nil { + abort(log, "error creating reva runtime: %v", err) + } + registerReva(reva) + if err := reva.Start(); err != nil { + abort(log, "error starting reva: %v", err) + } +} - return path.Join(os.TempDir(), name) +func abort(log *zerolog.Logger, format string, a ...any) { + log.Fatal().Msgf(format, a...) } -func runMultiple(confs []map[string]interface{}) { +func runMultiple(confs []*config.Config) { var wg sync.WaitGroup + for _, conf := range confs { wg.Add(1) pidfile := getPidfile() - go func(wg *sync.WaitGroup, conf map[string]interface{}) { + go func(wg *sync.WaitGroup, conf *config.Config) { defer wg.Done() - runtime.Run(conf, pidfile, *logFlag) + runSingle(conf, pidfile) }(&wg, conf) } wg.Wait() os.Exit(0) } + +func getPidfile() string { + uuid := uuid.New().String() + name := fmt.Sprintf("revad-%s.pid", uuid) + + return path.Join(os.TempDir(), name) +} + +func initLogger(conf *config.Log) *zerolog.Logger { + log, err := newLogger(conf) + if err != nil { + fmt.Fprintf(os.Stderr, "error creating logger: %v", err) + os.Exit(1) + } + return log +} + +func newLogger(conf *config.Log) (*zerolog.Logger, error) { + // TODO(labkode): use debug level rather than info as default until reaching a stable version. + // Helps having smaller development files. + if conf.Level == "" { + conf.Level = zerolog.DebugLevel.String() + } + + var opts []logger.Option + opts = append(opts, logger.WithLevel(conf.Level)) + + w, err := getWriter(conf.Output) + if err != nil { + return nil, err + } + + opts = append(opts, logger.WithWriter(w, logger.Mode(conf.Mode))) + + l := logger.New(opts...) + sub := l.With().Int("pid", os.Getpid()).Logger() + return &sub, nil +} + +func getWriter(out string) (io.Writer, error) { + if out == "stderr" || out == "" { + return os.Stderr, nil + } + + if out == "stdout" { + return os.Stdout, nil + } + + fd, err := os.OpenFile(out, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, errors.Wrap(err, "error creating log file: "+out) + } + + return fd, nil +} diff --git a/cmd/revad/pkg/config/common.go b/cmd/revad/pkg/config/common.go new file mode 100644 index 0000000000..fa4ce3148b --- /dev/null +++ b/cmd/revad/pkg/config/common.go @@ -0,0 +1,226 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "fmt" + "net" + + "github.com/mitchellh/mapstructure" +) + +type iterable interface { + services() map[string]ServicesConfig + interceptors() map[string]map[string]any +} + +type iterableImpl struct{ i iterable } + +// ServicesConfig holds the configuration for reva services. +type ServicesConfig []*DriverConfig + +// DriversNumber return the number of driver configured for the service. +func (s ServicesConfig) DriversNumber() int { return len(s) } + +// DriverConfig holds the configuration for a driver. +type DriverConfig struct { + Address Address `key:"address"` + Network string `key:"network"` + Label string `key:"-"` + Config map[string]any `key:",squash"` // this must be at the bottom! +} + +// Add appends the driver configuration to the given list of services. +func (s *ServicesConfig) Add(domain, svc string, c *DriverConfig) { + l := len(*s) + if l == 0 { + // the label is simply the service name + c.Label = domain + "_" + svc + } else { + c.Label = label(domain, svc, l) + if l == 1 { + (*s)[0].Label = label(domain, svc, 0) + } + } + *s = append(*s, c) +} + +func newSvcConfigFromList(domain, name string, l []map[string]any) (ServicesConfig, error) { + cfg := make(ServicesConfig, 0, len(l)) + for _, c := range l { + cfg.Add(domain, name, &DriverConfig{Config: c}) + } + return cfg, nil +} + +func newSvcConfigFromMap(domain, name string, m map[string]any) ServicesConfig { + s, _ := newSvcConfigFromList(domain, name, []map[string]any{m}) + return s +} + +func parseServices(domain string, cfg map[string]any) (map[string]ServicesConfig, error) { + // parse services + svcCfg, ok := cfg["services"].(map[string]any) + if !ok { + return nil, fmt.Errorf("%s.services must be a map", domain) + } + + services := make(map[string]ServicesConfig) + for name, cfg := range svcCfg { + // cfg can be a list or a map + cfgLst, ok := cfg.([]map[string]any) + if ok { + s, err := newSvcConfigFromList(domain, name, cfgLst) + if err != nil { + return nil, err + } + services[name] = s + continue + } + cfgMap, ok := cfg.(map[string]any) + if !ok { + return nil, fmt.Errorf("%s.services.%s must be a list or a map. got %T", domain, name, cfg) + } + services[name] = newSvcConfigFromMap(domain, name, cfgMap) + } + + return services, nil +} + +func parseMiddlwares(cfg map[string]any, key string) (map[string]map[string]any, error) { + m := make(map[string]map[string]any) + + mid, ok := cfg[key] + if !ok { + return m, nil + } + + if err := mapstructure.Decode(mid, &m); err != nil { + return nil, err + } + return m, nil +} + +// Service contains the configuration for a service. +type Service struct { + Address Address + Network string + Name string + Label string + Config map[string]any + + raw *DriverConfig +} + +// SetAddress sets the address for the service in the configuration. +func (s *Service) SetAddress(address Address) { + s.Address = address + s.raw.Address = address +} + +// ServiceFunc is an helper function used to pass the service config +// to the ForEachService func. +type ServiceFunc func(*Service) + +// Interceptor contains the configuration for an interceptor. +type Interceptor struct { + Name string + Config map[string]any +} + +// InterceptorFunc is an helper function used to pass the interface config +// to the ForEachInterceptor func. +type InterceptorFunc func(*Interceptor) + +// ForEachService iterates to each service/driver calling the function f. +func (i iterableImpl) ForEachService(f ServiceFunc) { + if i.i == nil { + return + } + for name, c := range i.i.services() { + for _, cfg := range c { + f(&Service{ + raw: cfg, + Address: cfg.Address, + Network: cfg.Network, + Label: cfg.Label, + Name: name, + Config: cfg.Config, + }) + } + } +} + +func label(domain, name string, i int) string { + return fmt.Sprintf("%s_%s_%d", domain, name, i) +} + +// ForEachInterceptor iterates to each middleware calling the function f. +func (i iterableImpl) ForEachInterceptor(f InterceptorFunc) { + for name, c := range i.i.interceptors() { + f(&Interceptor{ + Name: name, + Config: c, + }) + } +} + +func addressForService(global Address, cfg map[string]any) Address { + if address, ok := cfg["address"].(string); ok { + return Address(address) + } + return global +} + +func networkForService(global string, cfg map[string]any) string { + if network, ok := cfg["network"].(string); ok { + return network + } + return global +} + +// Address is the data structure holding an address. +type Address string + +// ensure Address implements the Lookuper interface. +var _ Lookuper = (*Address)(nil) + +// String return the string representation of the address. +func (a Address) String() string { return string(a) } + +// Get returns the value associated to the given key. +// The key available for an Address type are "port" and "ip", +// allowing respectively to get the port and the ip from the address. +func (a Address) Lookup(k string) (any, error) { + switch k { + case "port": + t, err := net.ResolveTCPAddr("tcp", a.String()) + if err != nil { + return nil, err + } + return t.Port, nil + case "ip": + t, err := net.ResolveTCPAddr("tcp", a.String()) + if err != nil { + return nil, err + } + return t.IP.String(), nil + } + return nil, ErrKeyNotFound{Key: k} +} diff --git a/cmd/revad/pkg/config/config.go b/cmd/revad/pkg/config/config.go new file mode 100644 index 0000000000..846b03374e --- /dev/null +++ b/cmd/revad/pkg/config/config.go @@ -0,0 +1,169 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "fmt" + "io" + "reflect" + + "github.com/BurntSushi/toml" + "github.com/creasty/defaults" + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" +) + +// Config holds the reva configuration. +type Config struct { + GRPC *GRPC `key:"grpc" mapstructure:"-" default:"{}"` + HTTP *HTTP `key:"http" mapstructure:"-" default:"{}"` + Serverless *Serverless `key:"serverless" mapstructure:"-" default:"{}"` + Shared *Shared `key:"shared" mapstructure:"shared" default:"{}"` + Log *Log `key:"log" mapstructure:"log" default:"{}" template:"-"` + Core *Core `key:"core" mapstructure:"core" default:"{}" template:"-"` + Vars Vars `key:"vars" mapstructure:"vars" default:"{}" template:"-"` +} + +// Log holds the configuration for the logger. +type Log struct { + Output string `key:"output" mapstructure:"output" default:"stdout"` + Mode string `key:"mode" mapstructure:"mode" default:"console"` + Level string `key:"level" mapstructure:"level" default:"trace"` +} + +// Shared holds the shared configuration. +type Shared struct { + JWTSecret string `key:"jwt_secret" mapstructure:"jwt_secret" default:"changemeplease"` + GatewaySVC string `key:"gatewaysvc" mapstructure:"gatewaysvc" default:"0.0.0.0:19000"` + DataGateway string `key:"datagateway" mapstructure:"datagateway" default:"http://0.0.0.0:19001/datagateway"` + SkipUserGroupsInToken bool `key:"skip_user_groups_in_token" mapstructure:"skip_user_groups_in_token"` + BlockedUsers []string `key:"blocked_users" mapstructure:"blocked_users" default:"[]"` +} + +// Core holds the core configuration. +type Core struct { + MaxCPUs string `key:"max_cpus" mapstructure:"max_cpus"` + TracingEnabled bool `key:"tracing_enabled" mapstructure:"tracing_enabled" default:"true"` + TracingEndpoint string `key:"tracing_endpoint" mapstructure:"tracing_endpoint" default:"localhost:6831"` + TracingCollector string `key:"tracing_collector" mapstructure:"tracing_collector"` + TracingServiceName string `key:"tracing_service_name" mapstructure:"tracing_service_name"` + TracingService string `key:"tracing_service" mapstructure:"tracing_service"` +} + +// Vars holds the a set of configuration paramenters that +// can be references by other parts of the configuration. +type Vars map[string]any + +// Lookuper is the interface for getting the value +// associated with a given key. +type Lookuper interface { + // Lookup get the value associated to thye given key. + // It returns ErrKeyNotFound if the key does not exists. + Lookup(key string) (any, error) +} + +// Load loads the configuration from the reader. +func Load(r io.Reader) (*Config, error) { + var c Config + if err := defaults.Set(&c); err != nil { + return nil, err + } + var raw map[string]any + if _, err := toml.NewDecoder(r).Decode(&raw); err != nil { + return nil, errors.Wrap(err, "config: error decoding toml data") + } + if err := c.parse(raw); err != nil { + return nil, err + } + return &c, nil +} + +func (c *Config) parse(raw map[string]any) error { + if err := c.parseGRPC(raw); err != nil { + return err + } + if err := c.parseHTTP(raw); err != nil { + return err + } + if err := c.parseServerless(raw); err != nil { + return err + } + if err := mapstructure.Decode(raw, c); err != nil { + return err + } + return nil +} + +// ApplyTemplates applies the templates defined in the configuration, +// replacing the template string with the value pointed by the given key. +func (c *Config) ApplyTemplates(l Lookuper) error { + return applyTemplateByType(l, nil, reflect.ValueOf(c)) +} + +// Dump returns the configuration as a map. +func (c *Config) Dump() map[string]any { + v := dumpByType(reflect.ValueOf(c)) + dump, ok := v.(map[string]any) + if !ok { + panic(fmt.Sprintf("dump should be a map: got %T", dump)) + } + return dump +} + +// Lookup gets the value associated to the given key in the config. +// The key is in the form .[], allowing accessing +// recursively the config on subfields, in case of maps or structs or +// types implementing the Getter interface, or elements in a list by the +// given index. +func (c *Config) Lookup(key string) (any, error) { + // check thet key is valid, meaning it starts with one of + // the fields of the config struct + if !c.isValidKey(key) { + return nil, nil + } + val, err := lookupByType(key, reflect.ValueOf(c)) + if err != nil { + return nil, errors.Wrapf(err, "lookup: error on key '%s'", key) + } + return val, nil +} + +func (c *Config) isValidKey(key string) bool { + cmd, _, err := parseNext(key) + if err != nil { + return false + } + f, ok := cmd.(FieldByKey) + if !ok { + return false + } + k := f.Key + e := reflect.TypeOf(c).Elem() + for i := 0; i < e.NumField(); i++ { + f := e.Field(i) + prefix := f.Tag.Get("key") + if prefix == "" || prefix == "-" { + continue + } + if k == prefix { + return true + } + } + return false +} diff --git a/cmd/revad/pkg/config/config_test.go b/cmd/revad/pkg/config/config_test.go new file mode 100644 index 0000000000..7ab22f8a03 --- /dev/null +++ b/cmd/revad/pkg/config/config_test.go @@ -0,0 +1,592 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoadGlobalGRPCAddress(t *testing.T) { + config := ` +[grpc] +address = "localhost:9142" + +[[grpc.services.authprovider]] +driver = "demo" +address = "localhost:9000" + +[grpc.services.authprovider.drivers.demo] +key = "value" + +[[grpc.services.authprovider]] +driver = "machine" +address = "localhost:9001" + +[grpc.services.authprovider.drivers.machine] +key = "value" + +[grpc.services.gateway] +something = "test"` + + c, err := Load(strings.NewReader(config)) + if err != nil { + t.Fatalf("not expected error: %v", err) + } + + assert.Equal(t, Address("localhost:9142"), c.GRPC.Address) + + exp := map[string]ServicesConfig{ + "authprovider": []*DriverConfig{ + { + Address: "localhost:9000", + Config: map[string]any{ + "driver": "demo", + "drivers": map[string]any{ + "demo": map[string]any{ + "key": "value", + }, + }, + "address": "localhost:9000", + }, + Network: "tcp", + Label: "grpc_authprovider_0", + }, + { + Address: "localhost:9001", + Config: map[string]any{ + "driver": "machine", + "address": "localhost:9001", + "drivers": map[string]any{ + "machine": map[string]any{ + "key": "value", + }, + }, + }, + Network: "tcp", + Label: "grpc_authprovider_1", + }, + }, + "gateway": []*DriverConfig{ + { + Address: "localhost:9142", + Config: map[string]any{ + "something": "test", + }, + Network: "tcp", + Label: "grpc_gateway", + }, + }, + } + assert.Equal(t, exp, c.GRPC.Services) +} + +func TestLoadNoGRPCDefaultAddress(t *testing.T) { + config := ` +[[grpc.services.authprovider]] +driver = "demo" +address = "localhost:9000" + +[grpc.services.authprovider.drivers.demo] +key = "value" + +[[grpc.services.authprovider]] +driver = "machine" +address = "localhost:9001" + +[grpc.services.authprovider.drivers.machine] +key = "value" + +[grpc.services.gateway] +something = "test"` + + c, err := Load(strings.NewReader(config)) + if err != nil { + t.Fatalf("not expected error: %v", err) + } + + assert.Equal(t, Address(""), c.GRPC.Address) + + exp := map[string]ServicesConfig{ + "authprovider": []*DriverConfig{ + { + Address: "localhost:9000", + Config: map[string]any{ + "driver": "demo", + "drivers": map[string]any{ + "demo": map[string]any{ + "key": "value", + }, + }, + "address": "localhost:9000", + }, + Network: "tcp", + Label: "grpc_authprovider_0", + }, + { + Address: "localhost:9001", + Config: map[string]any{ + "driver": "machine", + "address": "localhost:9001", + "drivers": map[string]any{ + "machine": map[string]any{ + "key": "value", + }, + }, + }, + Network: "tcp", + Label: "grpc_authprovider_1", + }, + }, + "gateway": []*DriverConfig{ + { + Config: map[string]any{ + "something": "test", + }, + Network: "tcp", + Label: "grpc_gateway", + }, + }, + } + assert.Equal(t, exp, c.GRPC.Services) +} + +func TestLoadFullConfig(t *testing.T) { + config := ` +[shared] +gatewaysvc = "localhost:9142" +jwt_secret = "secret" + +[log] +output = "/var/log/revad/revad-gateway.log" +mode = "json" +level = "trace" + +[core] +max_cpus = "1" +tracing_enabled = true + +[vars] +db_username = "root" +db_password = "secretpassword" + +[grpc] +shutdown_deadline = 10 +enable_reflection = true + +[grpc.services.gateway] +authregistrysvc = "{{ grpc.services.authregistry.address }}" + +[grpc.services.authregistry] +driver = "static" + +[grpc.services.authregistry.drivers.static.rules] +basic = "{{ grpc.services.authprovider[0].address }}" +machine = "{{ grpc.services.authprovider[1].address }}" + +[[grpc.services.authprovider]] +driver = "ldap" +address = "localhost:19000" + +[grpc.services.authprovider.drivers.ldap] +password = "ldap" + +[[grpc.services.authprovider]] +driver = "machine" +address = "localhost:19001" + +[grpc.services.authprovider.drivers.machine] +api_key = "secretapikey" + +[http] +address = "localhost:19002" + +[http.services.dataprovider] +driver = "localhome" + +[http.services.sysinfo] + +[serverless.services.notifications] +nats_address = "nats-server-01.example.com" +nats_token = "secret-token-example"` + + c2, err := Load(strings.NewReader(config)) + assert.ErrorIs(t, err, nil) + + assert.Equal(t, &Shared{ + GatewaySVC: "localhost:9142", + JWTSecret: "secret", + DataGateway: "http://0.0.0.0:19001/datagateway", + BlockedUsers: []string{}, + }, c2.Shared) + + assert.Equal(t, &Log{ + Output: "/var/log/revad/revad-gateway.log", + Mode: "json", + Level: "trace", + }, c2.Log) + + assert.Equal(t, &Core{ + MaxCPUs: "1", + TracingEnabled: true, + TracingEndpoint: "localhost:6831", + }, c2.Core) + + assert.Equal(t, Vars{ + "db_username": "root", + "db_password": "secretpassword", + }, c2.Vars) + + assertGRPCEqual(t, &GRPC{ + ShutdownDeadline: 10, + EnableReflection: true, + Network: "tcp", + Interceptors: make(map[string]map[string]any), + Services: map[string]ServicesConfig{ + "gateway": { + { + Config: map[string]any{ + "authregistrysvc": "{{ grpc.services.authregistry.address }}", + }, + Label: "grpc_gateway", + Network: "tcp", + }, + }, + "authregistry": { + { + Config: map[string]any{ + "driver": "static", + "drivers": map[string]any{ + "static": map[string]any{ + "rules": map[string]any{ + "basic": "{{ grpc.services.authprovider[0].address }}", + "machine": "{{ grpc.services.authprovider[1].address }}", + }, + }, + }, + }, + Label: "grpc_authregistry", + Network: "tcp", + }, + }, + "authprovider": { + { + Address: "localhost:19000", + Config: map[string]any{ + "driver": "ldap", + "address": "localhost:19000", + "drivers": map[string]any{ + "ldap": map[string]any{ + "password": "ldap", + }, + }, + }, + Label: "grpc_authprovider_0", + Network: "tcp", + }, + { + Address: "localhost:19001", + Config: map[string]any{ + "driver": "machine", + "address": "localhost:19001", + "drivers": map[string]any{ + "machine": map[string]any{ + "api_key": "secretapikey", + }, + }, + }, + Label: "grpc_authprovider_1", + Network: "tcp", + }, + }, + }, + }, c2.GRPC) + + assertHTTPEqual(t, &HTTP{ + Address: Address("localhost:19002"), + Network: "tcp", + Middlewares: make(map[string]map[string]any), + Services: map[string]ServicesConfig{ + "dataprovider": { + { + Address: "localhost:19002", + Config: map[string]any{ + "driver": "localhome", + }, + Network: "tcp", + Label: "http_dataprovider", + }, + }, + "sysinfo": { + { + Address: "localhost:19002", + Config: map[string]any{}, + Network: "tcp", + Label: "http_sysinfo", + }, + }, + }, + }, c2.HTTP) + + assert.Equal(t, &Serverless{ + Services: map[string]map[string]any{ + "notifications": { + "nats_address": "nats-server-01.example.com", + "nats_token": "secret-token-example", + }, + }, + }, c2.Serverless) +} + +func assertGRPCEqual(t *testing.T, g1, g2 *GRPC) { + assert.Equal(t, g1.Address, g2.Address) + assert.Equal(t, g1.Network, g2.Network) + assert.Equal(t, g1.ShutdownDeadline, g2.ShutdownDeadline) + assert.Equal(t, g1.EnableReflection, g2.EnableReflection) + assert.Equal(t, g1.Services, g2.Services) + assert.Equal(t, g1.Interceptors, g2.Interceptors) +} + +func assertHTTPEqual(t *testing.T, h1, h2 *HTTP) { + assert.Equal(t, h1.Network, h2.Network) + assert.Equal(t, h1.Network, h2.Network) + assert.Equal(t, h1.CertFile, h2.CertFile) + assert.Equal(t, h1.KeyFile, h2.KeyFile) + assert.Equal(t, h1.Services, h2.Services) + assert.Equal(t, h1.Middlewares, h2.Middlewares) +} + +func TestDump(t *testing.T) { + config := &Config{ + Shared: &Shared{ + GatewaySVC: "localhost:9142", + JWTSecret: "secret", + }, + Log: &Log{ + Output: "/var/log/revad/revad-gateway.log", + Mode: "json", + Level: "trace", + }, + Core: &Core{ + MaxCPUs: "1", + TracingEnabled: true, + }, + Vars: Vars{ + "db_username": "root", + "db_password": "secretpassword", + }, + GRPC: &GRPC{ + ShutdownDeadline: 10, + EnableReflection: true, + Interceptors: make(map[string]map[string]any), + Services: map[string]ServicesConfig{ + "gateway": { + { + Config: map[string]any{ + "authregistrysvc": "localhost:19000", + }, + }, + }, + "authregistry": { + { + Address: "localhost:19000", + Config: map[string]any{ + "driver": "static", + "drivers": map[string]any{ + "static": map[string]any{ + "rules": map[string]any{ + "basic": "localhost:19001", + "machine": "localhost:19002", + }, + }, + }, + }, + }, + }, + "authprovider": { + { + Address: "localhost:19001", + Config: map[string]any{ + "driver": "ldap", + "address": "localhost:19001", + "drivers": map[string]any{ + "ldap": map[string]any{ + "password": "ldap", + }, + }, + }, + }, + { + Address: "localhost:19002", + Config: map[string]any{ + "driver": "machine", + "address": "localhost:19002", + "drivers": map[string]any{ + "machine": map[string]any{ + "api_key": "secretapikey", + }, + }, + }, + }, + }, + }, + }, + HTTP: &HTTP{ + Address: "localhost:19003", + Middlewares: make(map[string]map[string]any), + Services: map[string]ServicesConfig{ + "dataprovider": { + { + Address: "localhost:19003", + Config: map[string]any{ + "driver": "localhome", + }, + }, + }, + "sysinfo": { + { + Address: "localhost:19003", + Config: map[string]any{}, + }, + }, + }, + }, + Serverless: &Serverless{ + Services: map[string]map[string]any{ + "notifications": { + "nats_address": "nats-server-01.example.com", + "nats_token": "secret-token-example", + }, + }, + }, + } + + m := config.Dump() + assert.Equal(t, map[string]any{ + "shared": map[string]any{ + "jwt_secret": "secret", + "gatewaysvc": "localhost:9142", + "datagateway": "", + "skip_user_groups_in_token": false, + "blocked_users": []any{}, + }, + "log": map[string]any{ + "output": "/var/log/revad/revad-gateway.log", + "mode": "json", + "level": "trace", + }, + "core": map[string]any{ + "max_cpus": "1", + "tracing_enabled": true, + "tracing_endpoint": "", + "tracing_collector": "", + "tracing_service_name": "", + "tracing_service": "", + }, + "vars": map[string]any{ + "db_username": "root", + "db_password": "secretpassword", + }, + "grpc": map[string]any{ + "address": Address(""), + "network": "", + "shutdown_deadline": 10, + "enable_reflection": true, + "interceptors": map[string]any{}, + "services": map[string]any{ + "gateway": []any{ + map[string]any{ + "address": Address(""), + "network": "", + "authregistrysvc": "localhost:19000", + }, + }, + "authregistry": []any{ + map[string]any{ + "address": Address("localhost:19000"), + "network": "", + "driver": "static", + "drivers": map[string]any{ + "static": map[string]any{ + "rules": map[string]any{ + "basic": "localhost:19001", + "machine": "localhost:19002", + }, + }, + }, + }, + }, + "authprovider": []any{ + map[string]any{ + "address": "localhost:19001", + "network": "", + "driver": "ldap", + "drivers": map[string]any{ + "ldap": map[string]any{ + "password": "ldap", + }, + }, + }, + map[string]any{ + "address": "localhost:19002", + "network": "", + "driver": "machine", + "drivers": map[string]any{ + "machine": map[string]any{ + "api_key": "secretapikey", + }, + }, + }, + }, + }, + }, + "http": map[string]any{ + "network": "", + "address": Address("localhost:19003"), + "certfile": "", + "keyfile": "", + "middlewares": map[string]any{}, + "services": map[string]any{ + "dataprovider": []any{ + map[string]any{ + "address": Address("localhost:19003"), + "network": "", + "driver": "localhome", + }, + }, + "sysinfo": []any{ + map[string]any{ + "address": Address("localhost:19003"), + "network": "", + }, + }, + }, + }, + "serverless": map[string]any{ + "services": map[string]any{ + "notifications": map[string]any{ + "nats_address": "nats-server-01.example.com", + "nats_token": "secret-token-example", + }, + }, + }, + }, m) +} diff --git a/cmd/revad/pkg/config/dump.go b/cmd/revad/pkg/config/dump.go new file mode 100644 index 0000000000..e819e5eeb4 --- /dev/null +++ b/cmd/revad/pkg/config/dump.go @@ -0,0 +1,146 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import "reflect" + +func dumpStruct(v reflect.Value) map[string]any { + if v.Kind() != reflect.Struct { + panic("called dumpStruct on non struct type") + } + + n := v.NumField() + m := make(map[string]any, n) + + t := v.Type() + for i := 0; i < n; i++ { + e := v.Field(i) + f := t.Field(i) + + if !f.IsExported() { + continue + } + + if isFieldSquashed(f) { + if e.Kind() == reflect.Pointer { + e = e.Elem() + } + + var mm map[string]any + switch e.Kind() { + case reflect.Struct: + mm = dumpStruct(e) + case reflect.Map: + mm = dumpMap(e) + default: + panic("squash not allowed on non map/struct types") + } + for k, v := range mm { + m[k] = v + } + continue + } + + n := fieldName(f) + if n == "-" { + continue + } + + m[n] = dumpByType(e) + } + return m +} + +func fieldName(f reflect.StructField) string { + fromtag := f.Tag.Get("key") + if fromtag != "" { + return fromtag + } + return f.Name +} + +func isFieldSquashed(f reflect.StructField) bool { + tag := f.Tag.Get("key") + return tag != "" && tag[1:] == "squash" +} + +func dumpMap(v reflect.Value) map[string]any { + if v.Kind() != reflect.Map { + panic("called dumpMap on non map type") + } + + m := make(map[string]any, v.Len()) + iter := v.MapRange() + for iter.Next() { + k := iter.Key() + e := iter.Value() + + key, ok := k.Interface().(string) + if !ok { + panic("key map must be a string") + } + + m[key] = dumpByType(e) + } + return m +} + +func dumpList(v reflect.Value) []any { + if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { + panic("called dumpList on non array/slice type") + } + + n := v.Len() + l := make([]any, 0, n) + + for i := 0; i < n; i++ { + e := v.Index(i) + l = append(l, dumpByType(e)) + } + return l +} + +func dumpPrimitive(v reflect.Value) any { + if v.Kind() != reflect.Bool && v.Kind() != reflect.Int && v.Kind() != reflect.Int8 && + v.Kind() != reflect.Int16 && v.Kind() != reflect.Int32 && v.Kind() != reflect.Int64 && + v.Kind() != reflect.Uint && v.Kind() != reflect.Uint8 && v.Kind() != reflect.Uint16 && + v.Kind() != reflect.Uint32 && v.Kind() != reflect.Uint64 && v.Kind() != reflect.Float32 && + v.Kind() != reflect.Float64 && v.Kind() != reflect.String { + panic("called dumpPrimitive on non primitive type: " + v.Kind().String()) + } + return v.Interface() +} + +func dumpByType(v reflect.Value) any { + switch v.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.String: + return dumpPrimitive(v) + case reflect.Array, reflect.Slice: + return dumpList(v) + case reflect.Struct: + return dumpStruct(v) + case reflect.Map: + return dumpMap(v) + case reflect.Interface, reflect.Pointer: + return dumpByType(v.Elem()) + } + panic("type not supported: " + v.Kind().String()) +} diff --git a/cmd/revad/pkg/config/dump_test.go b/cmd/revad/pkg/config/dump_test.go new file mode 100644 index 0000000000..cc214afab8 --- /dev/null +++ b/cmd/revad/pkg/config/dump_test.go @@ -0,0 +1,247 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDumpMap(t *testing.T) { + tests := []struct { + in map[string]any + exp map[string]any + }{ + { + in: map[string]any{}, + exp: map[string]any{}, + }, + { + in: map[string]any{ + "simple": SimpleStruct{ + KeyA: "value_a", + KeyB: "value_b", + }, + }, + exp: map[string]any{ + "simple": map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + }, + }, + { + in: map[string]any{ + "simple": SimpleStruct{ + KeyA: "value_a", + KeyB: "value_b", + }, + "map": map[string]any{ + "mapa": "value_mapa", + "mapb": "value_mapb", + }, + }, + exp: map[string]any{ + "simple": map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + "map": map[string]any{ + "mapa": "value_mapa", + "mapb": "value_mapb", + }, + }, + }, + } + + for _, tt := range tests { + m := dumpMap(reflect.ValueOf(tt.in)) + assert.Equal(t, m, tt.exp) + } +} + +func TestDumpList(t *testing.T) { + tests := []struct { + in []any + exp []any + }{ + { + in: []any{}, + exp: []any{}, + }, + { + in: []any{1, 2, 3, 4}, + exp: []any{1, 2, 3, 4}, + }, + { + in: []any{ + map[string]any{ + "map": SimpleStruct{ + KeyA: "value_a", + KeyB: "value_b", + }, + }, + 5, + SimpleStruct{ + KeyA: "value_a", + KeyB: "value_b", + }, + }, + exp: []any{ + map[string]any{ + "map": map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + }, + 5, + map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + }, + }, + } + + for _, tt := range tests { + l := dumpList(reflect.ValueOf(tt.in)) + assert.Equal(t, l, tt.exp) + } +} + +func TestDumpStruct(t *testing.T) { + tests := []struct { + in any + exp map[string]any + }{ + { + in: SimpleStruct{ + KeyA: "value_a", + KeyB: "value_b", + }, + exp: map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + }, + { + in: NestedStruct{ + Nested: SimpleStruct{ + KeyA: "value_a", + KeyB: "value_b", + }, + Value: 12, + }, + exp: map[string]any{ + "nested": map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + "value": 12, + }, + }, + { + in: StructWithNestedMap{ + Map: map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + }, + exp: map[string]any{ + "map": map[string]any{ + "keya": "value_a", + "keyb": "value_b", + }, + }, + }, + { + in: StructWithNestedList{ + List: []SimpleStruct{ + { + KeyA: "value_a[1]", + KeyB: "value_b[1]", + }, + { + KeyA: "value_a[2]", + KeyB: "value_b[2]", + }, + }, + }, + exp: map[string]any{ + "list": []any{ + map[string]any{ + "keya": "value_a[1]", + "keyb": "value_b[1]", + }, + map[string]any{ + "keya": "value_a[2]", + "keyb": "value_b[2]", + }, + }, + }, + }, + { + in: Squashed{ + Squashed: SimpleStruct{ + KeyA: "value_a[1]", + KeyB: "value_b[1]", + }, + Simple: SimpleStruct{ + KeyA: "value_a[2]", + KeyB: "value_b[2]", + }, + }, + exp: map[string]any{ + "keya": "value_a[1]", + "keyb": "value_b[1]", + "Simple": map[string]any{ + "keya": "value_a[2]", + "keyb": "value_b[2]", + }, + }, + }, + { + in: SquashedMap{ + Squashed: map[string]any{ + "keya": "val_a[1]", + "keyb": "val_b[1]", + }, + Simple: SimpleStruct{ + KeyA: "val_a[2]", + KeyB: "val_b[2]", + }, + }, + exp: map[string]any{ + "keya": "val_a[1]", + "keyb": "val_b[1]", + "simple": map[string]any{ + "keya": "val_a[2]", + "keyb": "val_b[2]", + }, + }, + }, + } + + for _, tt := range tests { + s := dumpStruct(reflect.ValueOf(tt.in)) + assert.Equal(t, tt.exp, s) + } +} diff --git a/cmd/revad/pkg/config/grpc.go b/cmd/revad/pkg/config/grpc.go new file mode 100644 index 0000000000..26c67bea3d --- /dev/null +++ b/cmd/revad/pkg/config/grpc.go @@ -0,0 +1,77 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" +) + +// GRPC holds the configuration for the GRPC services. +type GRPC struct { + Address Address `mapstructure:"address" key:"address"` + Network string `mapstructure:"network" key:"network" default:"tcp"` + ShutdownDeadline int `mapstructure:"shutdown_deadline" key:"shutdown_deadline"` + EnableReflection bool `mapstructure:"enable_reflection" key:"enable_reflection"` + + Services map[string]ServicesConfig `mapstructure:"-" key:"services"` + Interceptors map[string]map[string]any `mapstructure:"-" key:"interceptors"` + + iterableImpl +} + +func (g *GRPC) services() map[string]ServicesConfig { return g.Services } +func (g *GRPC) interceptors() map[string]map[string]any { return g.Interceptors } + +func (c *Config) parseGRPC(raw map[string]any) error { + cfg, ok := raw["grpc"] + if !ok { + return nil + } + if err := mapstructure.Decode(cfg, c.GRPC); err != nil { + return errors.Wrap(err, "config: error decoding grpc config") + } + + cfgGRPC, ok := cfg.(map[string]any) + if !ok { + return errors.New("grpc must be a map") + } + + services, err := parseServices("grpc", cfgGRPC) + if err != nil { + return err + } + + interceptors, err := parseMiddlwares(cfgGRPC, "interceptors") + if err != nil { + return err + } + + c.GRPC.Services = services + c.GRPC.Interceptors = interceptors + c.GRPC.iterableImpl = iterableImpl{c.GRPC} + + for _, svc := range c.GRPC.Services { + for _, cfg := range svc { + cfg.Address = addressForService(c.GRPC.Address, cfg.Config) + cfg.Network = networkForService(c.HTTP.Network, cfg.Config) + } + } + return nil +} diff --git a/cmd/revad/pkg/config/http.go b/cmd/revad/pkg/config/http.go new file mode 100644 index 0000000000..e29d02e089 --- /dev/null +++ b/cmd/revad/pkg/config/http.go @@ -0,0 +1,77 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" +) + +// HTTP holds the configuration for the HTTP services. +type HTTP struct { + Network string `mapstructure:"network" key:"network" default:"tcp"` + Address Address `mapstructure:"address" key:"address"` + CertFile string `mapstructure:"certfile" key:"certfile"` + KeyFile string `mapstructure:"keyfile" key:"keyfile"` + + Services map[string]ServicesConfig `mapstructure:"-" key:"services"` + Middlewares map[string]map[string]any `mapstructure:"-" key:"middlewares"` + + iterableImpl +} + +func (h *HTTP) services() map[string]ServicesConfig { return h.Services } +func (h *HTTP) interceptors() map[string]map[string]any { return h.Middlewares } + +func (c *Config) parseHTTP(raw map[string]any) error { + cfg, ok := raw["http"] + if !ok { + return nil + } + if err := mapstructure.Decode(cfg, c.HTTP); err != nil { + return errors.Wrap(err, "config: error decoding http config") + } + + cfgHTTP, ok := cfg.(map[string]any) + if !ok { + return errors.New("http must be a map") + } + + services, err := parseServices("http", cfgHTTP) + if err != nil { + return err + } + + middlewares, err := parseMiddlwares(cfgHTTP, "middlewares") + if err != nil { + return err + } + + c.HTTP.Services = services + c.HTTP.Middlewares = middlewares + c.HTTP.iterableImpl = iterableImpl{c.HTTP} + + for _, svc := range c.HTTP.Services { + for _, cfg := range svc { + cfg.Address = addressForService(c.HTTP.Address, cfg.Config) + cfg.Network = networkForService(c.HTTP.Network, cfg.Config) + } + } + return nil +} diff --git a/cmd/revad/pkg/config/lookup.go b/cmd/revad/pkg/config/lookup.go new file mode 100644 index 0000000000..5edb2a98d6 --- /dev/null +++ b/cmd/revad/pkg/config/lookup.go @@ -0,0 +1,257 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "io" + "reflect" + + "github.com/pkg/errors" +) + +// ErrKeyNotFound is the error returned when a key does not exist +// in the configuration. +type ErrKeyNotFound struct { + Key string +} + +// Error returns a string representation of the ErrKeyNotFound error. +func (e ErrKeyNotFound) Error() string { + return "key '" + e.Key + "' not found in the configuration" +} + +// lookupStruct recursively looks up the key in the struct v. +// It panics if the value in v is not a struct. +// Only fields are allowed to be accessed. It bails out if +// an user wants to access by index. +// The struct is traversed considering the field tags. If the tag +// "key" is not specified for a field, the field is skipped in +// the lookup. If the tag specifies "squash", the field is treated +// as squashed. +func lookupStruct(key string, v reflect.Value) (any, error) { + if v.Kind() != reflect.Struct { + panic("called lookupStruct on non struct type") + } + + cmd, next, err := parseNext(key) + if errors.Is(err, io.EOF) { + return v.Interface(), nil + } + if err != nil { + return nil, err + } + + c, ok := cmd.(FieldByKey) + if !ok { + return nil, errors.New("call of index on struct type") + } + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + val := v.Field(i) + field := t.Field(i) + + if !field.IsExported() { + continue + } + + tag := field.Tag.Get("key") + if tag == "" { + continue + } + + if tag[1:] == "squash" { + if val.Kind() == reflect.Pointer { + val = val.Elem() + } + + var ( + v any + err error + ) + switch val.Kind() { + case reflect.Struct: + v, err = lookupStruct(key, val) + case reflect.Map: + v, err = lookupMap(key, val) + default: + panic("squash not allowed on non map/struct types") + } + var e ErrKeyNotFound + if errors.As(err, &e) { + continue + } + if err != nil { + return nil, err + } + return v, nil + } + + if tag != c.Key { + continue + } + + return lookupByType(next, val) + } + return nil, ErrKeyNotFound{Key: key} +} + +var typeLookuper = reflect.TypeOf((*Lookuper)(nil)).Elem() + +// lookupByType recursively looks up the given key in v. +func lookupByType(key string, v reflect.Value) (any, error) { + if v.Type().Implements(typeLookuper) { + if v, err := lookupFromLookuper(key, v); err == nil && v != nil { + return v, nil + } + } + switch v.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.String: + return lookupPrimitive(key, v) + case reflect.Array, reflect.Slice: + return lookupList(key, v) + case reflect.Struct: + return lookupStruct(key, v) + case reflect.Map: + return lookupMap(key, v) + case reflect.Interface, reflect.Pointer: + return lookupByType(key, v.Elem()) + } + panic("type not supported: " + v.Kind().String()) +} + +// lookupFromLookuper looks up the key in a Lookup value. +func lookupFromLookuper(key string, v reflect.Value) (any, error) { + g, ok := v.Interface().(Lookuper) + if !ok { + panic("called lookupFromLookuper on type not implementing Lookup interface") + } + + cmd, _, err := parseNext(key) + if errors.Is(err, io.EOF) { + return v.Interface(), nil + } + if err != nil { + return nil, err + } + + c, ok := cmd.(FieldByKey) + if !ok { + return nil, errors.New("call of index on getter type") + } + + return g.Lookup(c.Key) +} + +// lookupMap recursively looks up the given key in the map v. +// It panics if the value in v is not a map. +// Works similarly to lookupStruct. +func lookupMap(key string, v reflect.Value) (any, error) { + if v.Kind() != reflect.Map { + panic("called lookupMap on non map type") + } + + cmd, next, err := parseNext(key) + if errors.Is(err, io.EOF) { + return v.Interface(), nil + } + if err != nil { + return nil, err + } + + c, ok := cmd.(FieldByKey) + if !ok { + return nil, errors.New("call of index on map type") + } + + // lookup elemen in the map + el := v.MapIndex(reflect.ValueOf(c.Key)) + if !el.IsValid() { + return nil, ErrKeyNotFound{Key: key} + } + + return lookupByType(next, el) +} + +// lookupList recursively looks up the given key in the list v, +// in all the elements contained in the list. +// It panics if the value v is not a list. +// The elements can be addressed in general by index, but +// access by key is only allowed if the list contains exactly +// one element. +func lookupList(key string, v reflect.Value) (any, error) { + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + panic("called lookupList on non array/slice type") + } + + cmd, next, err := parseNext(key) + if errors.Is(err, io.EOF) { + return v.Interface(), nil + } + if err != nil { + return nil, err + } + + var el reflect.Value + switch c := cmd.(type) { + case FieldByIndex: + if c.Index < 0 || c.Index >= v.Len() { + return nil, errors.New("list index out of range") + } + el = v.Index(c.Index) + case FieldByKey: + // only allowed if the list contains only one element + if v.Len() != 1 { + return nil, errors.New("cannot access field by key on a non 1-elem list") + } + el = v.Index(0) + e, err := lookupByType("."+c.Key, el) + if err != nil { + return nil, err + } + el = reflect.ValueOf(e) + } + + return lookupByType(next, el) +} + +// lookupPrimitive gets the value from v. +// If the key tries to access by field or by index the value, +// an error is returned. +func lookupPrimitive(key string, v reflect.Value) (any, error) { + if v.Kind() != reflect.Bool && v.Kind() != reflect.Int && v.Kind() != reflect.Int8 && + v.Kind() != reflect.Int16 && v.Kind() != reflect.Int32 && v.Kind() != reflect.Int64 && + v.Kind() != reflect.Uint && v.Kind() != reflect.Uint8 && v.Kind() != reflect.Uint16 && + v.Kind() != reflect.Uint32 && v.Kind() != reflect.Uint64 && v.Kind() != reflect.Float32 && + v.Kind() != reflect.Float64 && v.Kind() != reflect.String { + panic("called lookupPrimitive on non primitive type: " + v.Kind().String()) + } + + _, _, err := parseNext(key) + if errors.Is(err, io.EOF) { + return v.Interface(), nil + } + if err != nil { + return nil, err + } + + return nil, errors.New("cannot address a value of type " + v.Kind().String()) +} diff --git a/cmd/revad/pkg/config/lookup_test.go b/cmd/revad/pkg/config/lookup_test.go new file mode 100644 index 0000000000..dcd4e1b9bb --- /dev/null +++ b/cmd/revad/pkg/config/lookup_test.go @@ -0,0 +1,198 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type SimpleStruct struct { + KeyA string `key:"keya"` + KeyB string `key:"keyb"` +} + +type NestedStruct struct { + Nested SimpleStruct `key:"nested"` + Value int `key:"value"` +} + +type StructWithNestedMap struct { + Map map[string]any `key:"map"` +} + +type StructWithNestedList struct { + List []SimpleStruct `key:"list"` +} + +type Squashed struct { + Squashed SimpleStruct `key:",squash"` + Simple SimpleStruct +} + +type SquashedMap struct { + Squashed map[string]any `key:",squash"` + Simple SimpleStruct `key:"simple"` +} + +type StructWithAddress struct { + Address Address `key:"address"` +} + +func TestLookupStruct(t *testing.T) { + tests := []struct { + in any + key string + val any + err error + }{ + { + in: SimpleStruct{ + KeyA: "val_a", + KeyB: "val_b", + }, + key: ".keyb", + val: "val_b", + }, + { + in: NestedStruct{ + Nested: SimpleStruct{ + KeyA: "val_a", + KeyB: "val_b", + }, + Value: 10, + }, + key: ".nested.keyb", + val: "val_b", + }, + { + in: NestedStruct{ + Nested: SimpleStruct{ + KeyA: "val_a", + KeyB: "val_b", + }, + Value: 10, + }, + key: ".value", + val: 10, + }, + { + in: StructWithNestedMap{ + Map: map[string]any{ + "key1": "val1", + "key2": "val2", + }, + }, + key: ".map.key1", + val: "val1", + }, + { + in: StructWithNestedList{ + List: []SimpleStruct{ + { + KeyA: "val_a[1]", + KeyB: "val_b[1]", + }, + { + KeyA: "val_a[2]", + KeyB: "val_b[2]", + }, + }, + }, + key: ".list[1].keyb", + val: "val_b[2]", + }, + { + in: StructWithNestedList{ + List: []SimpleStruct{ + { + KeyA: "val_a[1]", + KeyB: "val_b[1]", + }, + }, + }, + key: ".list.keya", + val: "val_a[1]", + }, + { + in: StructWithNestedList{ + List: []SimpleStruct{ + { + KeyA: "val_a[1]", + KeyB: "val_b[1]", + }, + { + KeyA: "val_a[2]", + KeyB: "val_b[2]", + }, + }, + }, + key: ".list[1]", + val: SimpleStruct{ + KeyA: "val_a[2]", + KeyB: "val_b[2]", + }, + }, + { + in: Squashed{ + Squashed: SimpleStruct{ + KeyA: "val_a[1]", + KeyB: "val_b[1]", + }, + Simple: SimpleStruct{ + KeyA: "val_a[2]", + KeyB: "val_b[2]", + }, + }, + key: ".keya", + val: "val_a[1]", + }, + { + in: SquashedMap{ + Squashed: map[string]any{ + "keya": "val_a[1]", + "keyb": "val_b[1]", + }, + Simple: SimpleStruct{ + KeyA: "val_a[2]", + KeyB: "val_b[2]", + }, + }, + key: ".keya", + val: "val_a[1]", + }, + { + in: StructWithAddress{ + Address: "188.184.37.219:9142", + }, + key: ".address.port", + val: 9142, + }, + } + + for _, tt := range tests { + got, err := lookupStruct(tt.key, reflect.ValueOf(tt.in)) + assert.Equal(t, err, tt.err, "got not expected error") + if tt.err == nil { + assert.Equal(t, tt.val, got) + } + } +} diff --git a/cmd/revad/pkg/config/parser.go b/cmd/revad/pkg/config/parser.go new file mode 100644 index 0000000000..aa25de6e70 --- /dev/null +++ b/cmd/revad/pkg/config/parser.go @@ -0,0 +1,100 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "io" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +// Command is the command to execute after parsing the template. +type Command interface{ isCommand() } + +// FieldByKey instructs the template runner to get a field by a key. +type FieldByKey struct{ Key string } + +func (FieldByKey) isCommand() {} + +// FieldByIndex instructs the template runner to get a field by an index. +type FieldByIndex struct{ Index int } + +func (FieldByIndex) isCommand() {} + +// parseNext reads the next token from the key and +// assings a command. +// If the key is empty io.EOF is returned. +func parseNext(key string) (Command, string, error) { + // key = ".grpc.services.authprovider[1].address" + + key = strings.TrimSpace(key) + + // first character must be either "." or "[" + // unless the key is empty + if key == "" { + return nil, "", io.EOF + } + + switch { + case strings.HasPrefix(key, "."): + tkn, next := split(key) + return FieldByKey{Key: tkn}, next, nil + case strings.HasPrefix(key, "["): + tkn, next := split(key) + index, err := strconv.ParseInt(tkn, 10, 64) + if err != nil { + return nil, "", errors.Wrap(err, "parsing error") + } + return FieldByIndex{Index: int(index)}, next, nil + } + + return nil, "", errors.New("parsing error: operator not recognised in key " + key) +} + +func split(key string) (token string, next string) { + // key = ".grpc.services.authprovider[1].address" + // -> grpc + // key = "[].address" + // -> + if key == "" { + return + } + + i := -1 + s := key[0] + key = key[1:] + + switch s { + case '.': + i = strings.IndexAny(key, ".[") + case '[': + i = strings.IndexByte(key, ']') + } + + if i == -1 { + return key, "" + } + + if key[i] == ']' { + return key[:i], key[i+1:] + } + return key[:i], key[i:] +} diff --git a/cmd/revad/pkg/config/parser_test.go b/cmd/revad/pkg/config/parser_test.go new file mode 100644 index 0000000000..87ba592aec --- /dev/null +++ b/cmd/revad/pkg/config/parser_test.go @@ -0,0 +1,97 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "testing" + + "gotest.tools/assert" +) + +func TestSplit(t *testing.T) { + tests := []struct { + key string + token string + next string + }{ + { + key: ".grpc.services.authprovider[1].address", + token: "grpc", + next: ".services.authprovider[1].address", + }, + { + key: "[1].address", + token: "1", + next: ".address", + }, + { + key: "[100].address", + token: "100", + next: ".address", + }, + { + key: "", + }, + } + + for _, tt := range tests { + token, next := split(tt.key) + assert.Equal(t, token, tt.token) + assert.Equal(t, next, tt.next) + } +} + +func TestParseNext(t *testing.T) { + tests := []struct { + key string + cmd Command + next string + err error + }{ + { + key: ".grpc.services.authprovider[1].address", + cmd: FieldByKey{Key: "grpc"}, + next: ".services.authprovider[1].address", + }, + { + key: ".authprovider[1].address", + cmd: FieldByKey{Key: "authprovider"}, + next: "[1].address", + }, + { + key: "[1].authprovider.address", + cmd: FieldByIndex{Index: 1}, + next: ".authprovider.address", + }, + { + key: ".authprovider", + cmd: FieldByKey{Key: "authprovider"}, + next: "", + }, + } + + for _, tt := range tests { + cmd, next, err := parseNext(tt.key) + assert.Equal(t, err, tt.err) + if tt.err == nil { + assert.Equal(t, cmd, tt.cmd) + assert.Equal(t, next, tt.next) + } + } +} diff --git a/cmd/revad/pkg/config/serverless.go b/cmd/revad/pkg/config/serverless.go new file mode 100644 index 0000000000..ea0c574d34 --- /dev/null +++ b/cmd/revad/pkg/config/serverless.go @@ -0,0 +1,54 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" +) + +// Serverless holds the configuration for the serverless services. +type Serverless struct { + Services map[string]map[string]any `key:"services" mapstructure:"services"` +} + +func (c *Config) parseServerless(raw map[string]any) error { + cfg, ok := raw["serverless"] + if !ok { + return nil + } + + var s Serverless + if err := mapstructure.Decode(cfg, &s); err != nil { + return errors.Wrap(err, "config: error decoding serverless config") + } + + c.Serverless = &s + return nil +} + +// ForEach iterates to each service calling the function f. +func (s *Serverless) ForEach(f func(name string, config map[string]any) error) error { + for name, cfg := range s.Services { + if err := f(name, cfg); err != nil { + return err + } + } + return nil +} diff --git a/cmd/revad/pkg/config/templates.go b/cmd/revad/pkg/config/templates.go new file mode 100644 index 0000000000..d4d3e4ae52 --- /dev/null +++ b/cmd/revad/pkg/config/templates.go @@ -0,0 +1,279 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "fmt" + "reflect" + "regexp" + "strconv" + "strings" +) + +// applyTemplateStruct applies recursively to all its fields all the template +// strings to the struct v. +// It panics if the value is not a struct. +// A field in the struct is skipped for applying all the templates +// if a tag "template" has the value "-". +func applyTemplateStruct(l Lookuper, p setter, v reflect.Value) error { + if v.Kind() != reflect.Struct { + panic("called applyTemplateStruct on non struct type") + } + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + el := v.Field(i) + f := t.Field(i) + + if !f.IsExported() { + continue + } + + if f.Tag.Get("template") == "-" { + // skip this field + continue + } + + if err := applyTemplateByType(l, setterStruct{Struct: v, Field: i}, el); err != nil { + return err + } + } + return nil +} + +// applyTemplateByType applies the template string to a generic type. +func applyTemplateByType(l Lookuper, p setter, v reflect.Value) error { + switch v.Kind() { + case reflect.Array, reflect.Slice: + return applyTemplateList(l, p, v) + case reflect.Struct: + return applyTemplateStruct(l, p, v) + case reflect.Map: + return applyTemplateMap(l, p, v) + case reflect.Interface: + return applyTemplateInterface(l, p, v) + case reflect.String: + return applyTemplateString(l, p, v) + case reflect.Pointer: + return applyTemplateByType(l, p, v.Elem()) + } + return nil +} + +// applyTemplateList recursively applies in all the elements of the list +// the template strings. +// It panics if the given value is not a list. +func applyTemplateList(l Lookuper, p setter, v reflect.Value) error { + if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { + panic("called applyTemplateList on non array/slice type") + } + + for i := 0; i < v.Len(); i++ { + el := v.Index(i) + if err := applyTemplateByType(l, setterList{List: v, Index: i}, el); err != nil { + return err + } + } + return nil +} + +// applyTemplateMap recursively applies in all the elements of the map +// the template strings. +// It panics if the given value is not a map. +func applyTemplateMap(l Lookuper, p setter, v reflect.Value) error { + if v.Kind() != reflect.Map { + panic("called applyTemplateMap on non map type") + } + + iter := v.MapRange() + for iter.Next() { + k := iter.Key() + el := v.MapIndex(k) + if err := applyTemplateByType(l, setterMap{Map: v, Key: k.Interface()}, el); err != nil { + return err + } + } + return nil +} + +// applyTemplateString applies to the string the template string, if any. +// It panics if the given value is not a string. +func applyTemplateString(l Lookuper, p setter, v reflect.Value) error { + if v.Kind() != reflect.String { + panic("called applyTemplateString on non string type") + } + s := v.String() + tmpl, is := isTemplate(s) + if !is { + // nothing to do + return nil + } + + key := keyFromTemplate(tmpl) + val, err := l.Lookup(key) + if err != nil { + return err + } + if val == nil { + return nil + } + + new, err := replaceTemplate(s, tmpl, val) + if err != nil { + return err + } + str, ok := convertToString(new) + if !ok { + return fmt.Errorf("value %v cannot be converted as string in the template %s", val, new) + } + + p.SetValue(str) + return nil +} + +// applyTemplateInterface applies to the interface the template string, if any. +// It panics if the given value is not an interface. +func applyTemplateInterface(l Lookuper, p setter, v reflect.Value) error { + if v.Kind() != reflect.Interface { + panic("called applyTemplateInterface on non interface value") + } + + s, ok := v.Interface().(string) + if !ok { + return applyTemplateByType(l, p, v.Elem()) + } + + tmpl, is := isTemplate(s) + if !is { + // nothing to do + return nil + } + + key := keyFromTemplate(tmpl) + val, err := l.Lookup(key) + if err != nil { + return err + } + if val == nil { + return nil + } + + new, err := replaceTemplate(s, tmpl, val) + if err != nil { + return err + } + p.SetValue(new) + return nil +} + +func replaceTemplate(original, tmpl string, val any) (any, error) { + if strings.TrimSpace(original) == tmpl { + // the value was directly a template, i.e. "{{ grpc.services.gateway.address }}" + return val, nil + } + // the value is of something like "something {{ template }} something else" + // in this case we need to replace the template string with the value, converted + // as string in the original val + s, ok := convertToString(val) + if !ok { + return nil, fmt.Errorf("value %v cannot be converted as string in the template %s", val, original) + } + return strings.Replace(original, tmpl, s, 1), nil +} + +func convertToString(val any) (string, bool) { + switch v := val.(type) { + case string: + return v, true + case fmt.Stringer: + return v.String(), true + case int: + return strconv.FormatInt(int64(v), 10), true + case int8: + return strconv.FormatInt(int64(v), 10), true + case int16: + return strconv.FormatInt(int64(v), 10), true + case int32: + return strconv.FormatInt(int64(v), 10), true + case uint: + return strconv.FormatUint(uint64(v), 10), true + case uint8: + return strconv.FormatUint(uint64(v), 10), true + case uint16: + return strconv.FormatUint(uint64(v), 10), true + case uint32: + return strconv.FormatUint(uint64(v), 10), true + case uint64: + return strconv.FormatUint(v, 10), true + case bool: + return strconv.FormatBool(v), true + } + return "", false +} + +var templateRegex = regexp.MustCompile("{{.{1,}}}") + +func isTemplate(s string) (string, bool) { + m := templateRegex.FindString(s) + return m, m != "" +} + +func keyFromTemplate(s string) string { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "{{") + s = strings.TrimSuffix(s, "}}") + return "." + strings.TrimSpace(s) +} + +type setter interface { + // SetValue sets the value v in a container. + SetValue(v any) +} + +type setterList struct { + List reflect.Value + Index int +} + +type setterMap struct { + Map reflect.Value + Key any +} + +type setterStruct struct { + Struct reflect.Value + Field int +} + +// SetValue sets the value v in the element +// of the list. +func (s setterList) SetValue(v any) { + el := s.List.Index(s.Index) + el.Set(reflect.ValueOf(v)) +} + +// SetValue sets the value v to the element of the map. +func (s setterMap) SetValue(v any) { + s.Map.SetMapIndex(reflect.ValueOf(s.Key), reflect.ValueOf(v)) +} + +// SetValue sets the value v to the field in the struct. +func (s setterStruct) SetValue(v any) { + s.Struct.Field(s.Field).Set(reflect.ValueOf(v)) +} diff --git a/cmd/revad/pkg/config/templates_test.go b/cmd/revad/pkg/config/templates_test.go new file mode 100644 index 0000000000..9db0167bec --- /dev/null +++ b/cmd/revad/pkg/config/templates_test.go @@ -0,0 +1,135 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestApplyTemplate(t *testing.T) { + cfg1 := &Config{ + GRPC: &GRPC{ + Services: map[string]ServicesConfig{ + "authprovider": { + { + Address: "localhost:1900", + }, + }, + "authregistry": { + { + Address: "localhost:1901", + Config: map[string]any{ + "drivers": map[string]any{ + "static": map[string]any{ + "demo": "{{ grpc.services.authprovider.address }}", + }, + }, + }, + }, + }, + "other": { + { + Address: "localhost:1902", + Config: map[string]any{ + "drivers": map[string]any{ + "static": map[string]any{ + "demo": "https://{{ grpc.services.authprovider.address }}/data", + }, + }, + }, + }, + }, + "port": { + { + + Config: map[string]any{ + "drivers": map[string]any{ + "static": map[string]any{ + "demo": "https://cern.ch:{{ grpc.services.authprovider.address.port }}/data", + }, + }, + }, + }, + }, + }, + }, + } + err := cfg1.ApplyTemplates(cfg1) + assert.ErrorIs(t, err, nil) + assert.Equal(t, Address("localhost:1900"), cfg1.GRPC.Services["authregistry"][0].Config["drivers"].(map[string]any)["static"].(map[string]any)["demo"]) + assert.Equal(t, "https://localhost:1900/data", cfg1.GRPC.Services["other"][0].Config["drivers"].(map[string]any)["static"].(map[string]any)["demo"]) + assert.Equal(t, "https://cern.ch:1900/data", cfg1.GRPC.Services["port"][0].Config["drivers"].(map[string]any)["static"].(map[string]any)["demo"]) + + cfg2 := &Config{ + Shared: &Shared{ + GatewaySVC: "{{ grpc.services.authregistry.address }}", + }, + Vars: Vars{ + "db_username": "root", + "db_password": "secretpassword", + "integer": 10, + }, + GRPC: &GRPC{ + Services: map[string]ServicesConfig{ + "authregistry": { + { + Address: "localhost:1901", + Config: map[string]any{ + "drivers": map[string]any{ + "sql": map[string]any{ + "db_username": "{{ vars.db_username }}", + "db_password": "{{ vars.db_password }}", + "key": "value", + "int": "{{ vars.integer }}", + }, + }, + }, + }, + }, + "other": { + { + Address: "localhost:1902", + Config: map[string]any{ + "drivers": map[string]any{ + "sql": map[string]any{ + "db_host": "http://localhost:{{ vars.integer }}", + }, + }, + }, + }, + }, + }, + }, + } + + err = cfg2.ApplyTemplates(cfg2) + assert.ErrorIs(t, err, nil) + assert.Equal(t, "localhost:1901", cfg2.Shared.GatewaySVC) + assert.Equal(t, map[string]any{ + "db_username": "root", + "db_password": "secretpassword", + "key": "value", + "int": 10, + }, cfg2.GRPC.Services["authregistry"][0].Config["drivers"].(map[string]any)["sql"]) + assert.Equal(t, map[string]any{ + "db_host": "http://localhost:10", + }, cfg2.GRPC.Services["other"][0].Config["drivers"].(map[string]any)["sql"]) +} diff --git a/cmd/revad/internal/grace/grace.go b/cmd/revad/pkg/grace/grace.go similarity index 68% rename from cmd/revad/internal/grace/grace.go rename to cmd/revad/pkg/grace/grace.go index 52cd3619a8..6640a40484 100644 --- a/cmd/revad/internal/grace/grace.go +++ b/cmd/revad/pkg/grace/grace.go @@ -29,6 +29,7 @@ import ( "syscall" "time" + netutil "github.com/cs3org/reva/pkg/utils/net" "github.com/pkg/errors" "github.com/rs/zerolog" ) @@ -40,12 +41,14 @@ type Watcher struct { graceful bool ppid int lns map[string]net.Listener - ss map[string]Server + ss []Server SL Serverless pidFile string childPIDs []int } +const revaEnvPrefix = "REVA_FD_" + // Option represent an option. type Option func(w *Watcher) @@ -69,7 +72,7 @@ func NewWatcher(opts ...Option) *Watcher { log: zerolog.Nop(), graceful: os.Getenv("GRACEFUL") == "true", ppid: os.Getppid(), - ss: map[string]Server{}, + ss: make([]Server, 0), } for _, opt := range opts { @@ -82,13 +85,18 @@ func NewWatcher(opts ...Option) *Watcher { // Exit exits the current process cleaning up // existing pid files. func (w *Watcher) Exit(errc int) { + w.Clean() + os.Exit(errc) +} + +// Clean cleans up existing pid files. +func (w *Watcher) Clean() { err := w.clean() if err != nil { w.log.Warn().Err(err).Msg("error removing pid file") } else { w.log.Info().Msgf("pid file %q got removed", w.pidFile) } - os.Exit(errc) } func (w *Watcher) clean() error { @@ -186,29 +194,122 @@ func newListener(network, addr string) (net.Listener, error) { return net.Listen(network, addr) } +// implements the net.Listener interface. +type inherited struct { + f *os.File + ln net.Listener +} + +func (i *inherited) Accept() (net.Conn, error) { + return i.ln.Accept() +} + +func (i *inherited) Close() error { + // TODO: improve this: if file close has error + // the listener is not closed + if err := i.f.Close(); err != nil { + return err + } + return i.ln.Close() +} + +func (i *inherited) Addr() net.Addr { + return i.ln.Addr() +} + +func inheritedListeners() map[string]net.Listener { + lns := make(map[string]net.Listener) + for _, val := range os.Environ() { + if strings.HasPrefix(val, revaEnvPrefix) { + // env variable is of type REVA_FD_= + env := strings.TrimPrefix(val, revaEnvPrefix) + s := strings.Split(env, "=") + if len(s) != 2 { + continue + } + svcname := strings.ToLower(s[0]) + fd, err := strconv.ParseUint(s[1], 10, 64) + if err != nil { + continue + } + f := os.NewFile(uintptr(fd), "") + ln, err := net.FileListener(f) + if err != nil { + // TODO: log error + continue + } + lns[svcname] = &inherited{f: f, ln: ln} + } + } + return lns +} + +func isRandomAddress(addr string) bool { + return addr == "" +} + +func getAddress(addr string) string { + if isRandomAddress(addr) { + return ":0" + } + return addr +} + // GetListeners return grpc listener first and http listener second. -func (w *Watcher) GetListeners(servers map[string]Server) (map[string]net.Listener, error) { - w.ss = servers - lns := map[string]net.Listener{} +func (w *Watcher) GetListeners(servers map[string]Addressable) (map[string]net.Listener, error) { + lns := make(map[string]net.Listener) + if w.graceful { - w.log.Info().Msg("graceful restart, inheriting parent ln fds for grpc and http") - count := 3 - for k, s := range servers { - network, addr := s.Network(), s.Address() - fd := os.NewFile(uintptr(count), "") // 3 because ExtraFile passed to new process - count++ - ln, err := net.FileListener(fd) - if err != nil { - w.log.Error().Err(err).Msg("error creating net.Listener from fd") - // create new fd - ln, err := newListener(network, addr) + w.log.Info().Msg("graceful restart, inheriting parent listener fds for grpc and http services") + + inherited := inheritedListeners() + logListeners(inherited, "inherited", &w.log) + + for svc, ln := range inherited { + addr, ok := servers[svc] + if !ok { + continue + } + // for services with random addresses, check and assign if available from inherited + // from the assigned addresses, assing the listener if address correspond + if isRandomAddress(addr.Address()) || + netutil.AddressEqual(ln.Addr(), addr.Network(), addr.Address()) { + lns[svc] = ln + } + } + + // close all the listeners not used from inherited + for svc, ln := range inherited { + if _, ok := lns[svc]; !ok { + w.log.Debug().Msgf("closing inherited listener %s:%s for service %s", ln.Addr().Network(), ln.Addr().String(), svc) + if err := ln.Close(); err != nil { + w.log.Error().Err(err).Msgf("error closing inherited listener %s:%s", ln.Addr().Network(), ln.Addr().String()) + return nil, errors.Wrap(err, "error closing inherited listener") + } + } + } + + var err error + // create assigned/random listeners for the missing services + for svc, a := range servers { + _, ok := lns[svc] + if ok { + continue + } + network, addr := a.Network(), getAddress(a.Address()) + // multiple services may have the same listener + ln, ok := get(lns, addr, network) + if !ok { + ln, err = newListener(network, addr) if err != nil { return nil, err } - lns[k] = ln - } else { - lns[k] = ln } + if err != nil { + w.log.Error().Err(err).Msgf("error getting listener on %s", addr) + return nil, errors.Wrap(err, "error getting listener") + } + lns[svc] = ln } // kill parent @@ -233,26 +334,55 @@ func (w *Watcher) GetListeners(servers map[string]Server) (map[string]net.Listen return lns, nil } - // create two listeners for grpc and http - for k, s := range servers { - network, addr := s.Network(), s.Address() - ln, err := newListener(network, addr) - if err != nil { - return nil, err + var err error + // no graceful + for svc, s := range servers { + network, addr := s.Network(), getAddress(s.Address()) + // multiple services may have the same listener + ln, ok := get(lns, addr, network) + if !ok { + ln, err = newListener(network, addr) + if err != nil { + return nil, err + } } - lns[k] = ln + w.log.Debug(). + Msgf("listener for %s assigned to %s:%s", svc, ln.Addr().Network(), ln.Addr().String()) + lns[svc] = ln } w.lns = lns return lns, nil } +func logListeners(lns map[string]net.Listener, info string, log *zerolog.Logger) { + r := make(map[string]string, len(lns)) + for n, ln := range lns { + r[n] = fmt.Sprintf("%s:%s", ln.Addr().Network(), ln.Addr().String()) + } + log.Debug().Interface(info, r).Send() +} + +func get(lns map[string]net.Listener, address, network string) (net.Listener, bool) { + for _, ln := range lns { + if netutil.AddressEqual(ln.Addr(), network, address) { + return ln, true + } + } + return nil, false +} + +// Addressable is the interface for exposing address info. +type Addressable interface { + Network() string + Address() string +} + // Server is the interface that servers like HTTP or gRPC // servers need to implement. type Server interface { - Stop() error - GracefulStop() error - Network() string - Address() string + Start(net.Listener) error + Serverless + Addressable } // Serverless is the interface that the serverless server implements. @@ -261,6 +391,12 @@ type Serverless interface { GracefulStop() error } +// SetServers sets the list of servers that have to be watched. +func (w *Watcher) SetServers(s []Server) { w.ss = s } + +// SetServerless sets the serverless that has to be watched. +func (w *Watcher) SetServerless(s Serverless) { w.SL = s } + // TrapSignals captures the OS signal. func (w *Watcher) TrapSignals() { signalCh := make(chan os.Signal, 1024) @@ -350,6 +486,8 @@ func (w *Watcher) TrapSignals() { func getListenerFile(ln net.Listener) (*os.File, error) { switch t := ln.(type) { + case *inherited: + return t.f, nil case *net.TCPListener: return t.File() case *net.UnixListener: @@ -361,13 +499,13 @@ func getListenerFile(ln net.Listener) (*os.File, error) { func forkChild(lns map[string]net.Listener) (*os.Process, error) { // Get the file descriptor for the listener and marshal the metadata to pass // to the child in the environment. - fds := map[string]*os.File{} - for k, ln := range lns { + fds := make(map[string]*os.File, 0) + for name, ln := range lns { fd, err := getListenerFile(ln) if err != nil { return nil, err } - fds[k] = fd + fds[name] = fd } // Pass stdin, stdout, and stderr along with the listener file to the child @@ -379,10 +517,10 @@ func forkChild(lns map[string]net.Listener) (*os.Process, error) { // Get current environment and add in the listener to it. environment := append(os.Environ(), "GRACEFUL=true") - var counter = 3 + counter := 3 for k, fd := range fds { k = strings.ToUpper(k) - environment = append(environment, k+"FD="+fmt.Sprintf("%d", counter)) + environment = append(environment, fmt.Sprintf("%s%s=%d", revaEnvPrefix, k, counter)) files = append(files, fd) counter++ } @@ -402,7 +540,7 @@ func forkChild(lns map[string]net.Listener) (*os.Process, error) { Sys: &syscall.SysProcAttr{}, }) - // TODO(labkode): if the process dies (because config changed and is wrong + // TODO(labkode): if the process dies (because config changed and is wrong) // we need to return an error if err != nil { return nil, err diff --git a/cmd/revad/runtime/grpc.go b/cmd/revad/runtime/grpc.go new file mode 100644 index 0000000000..5cd1b4b195 --- /dev/null +++ b/cmd/revad/runtime/grpc.go @@ -0,0 +1,150 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package runtime + +import ( + "sort" + + "github.com/cs3org/reva/internal/grpc/interceptors/appctx" + "github.com/cs3org/reva/internal/grpc/interceptors/auth" + "github.com/cs3org/reva/internal/grpc/interceptors/log" + "github.com/cs3org/reva/internal/grpc/interceptors/recovery" + "github.com/cs3org/reva/internal/grpc/interceptors/token" + "github.com/cs3org/reva/internal/grpc/interceptors/useragent" + "github.com/cs3org/reva/pkg/rgrpc" + rtrace "github.com/cs3org/reva/pkg/trace" + "github.com/pkg/errors" + "github.com/rs/zerolog" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/grpc" +) + +type unaryInterceptorTriple struct { + Name string + Priority int + Interceptor grpc.UnaryServerInterceptor +} + +type streamInterceptorTriple struct { + Name string + Priority int + Interceptor grpc.StreamServerInterceptor +} + +func initGRPCInterceptors(conf map[string]map[string]any, unprotected []string, logger *zerolog.Logger) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor, error) { + unaryTriples := []*unaryInterceptorTriple{} + for name, c := range conf { + new, ok := rgrpc.UnaryInterceptors[name] + if !ok { + continue + } + inter, prio, err := new(c) + if err != nil { + return nil, nil, errors.Wrap(err, "error creating unary interceptor: "+name) + } + triple := &unaryInterceptorTriple{ + Name: name, + Priority: prio, + Interceptor: inter, + } + unaryTriples = append(unaryTriples, triple) + } + + sort.SliceStable(unaryTriples, func(i, j int) bool { + return unaryTriples[i].Priority < unaryTriples[j].Priority + }) + + authUnary, err := auth.NewUnary(conf["auth"], unprotected) + if err != nil { + return nil, nil, errors.Wrap(err, "error creating unary auth interceptor") + } + + unaryInterceptors := []grpc.UnaryServerInterceptor{authUnary} + for _, t := range unaryTriples { + unaryInterceptors = append(unaryInterceptors, t.Interceptor) + logger.Info().Msgf("rgrpc: chaining grpc unary interceptor %s with priority %d", t.Name, t.Priority) + } + + unaryInterceptors = append(unaryInterceptors, + otelgrpc.UnaryServerInterceptor( + otelgrpc.WithTracerProvider(rtrace.Provider), + otelgrpc.WithPropagators(rtrace.Propagator)), + ) + + unaryInterceptors = append([]grpc.UnaryServerInterceptor{ + appctx.NewUnary(*logger), + token.NewUnary(), + useragent.NewUnary(), + log.NewUnary(), + recovery.NewUnary(), + }, unaryInterceptors...) + + streamTriples := []*streamInterceptorTriple{} + for name, c := range conf { + new, ok := rgrpc.StreamInterceptors[name] + if !ok { + continue + } + inter, prio, err := new(c) + if err != nil { + if err != nil { + return nil, nil, errors.Wrapf(err, "error creating streaming interceptor: %s,", name) + } + triple := &streamInterceptorTriple{ + Name: name, + Priority: prio, + Interceptor: inter, + } + streamTriples = append(streamTriples, triple) + } + } + // sort stream triples + sort.SliceStable(streamTriples, func(i, j int) bool { + return streamTriples[i].Priority < streamTriples[j].Priority + }) + + authStream, err := auth.NewStream(conf["auth"], unprotected) + if err != nil { + return nil, nil, errors.Wrap(err, "error creating stream auth interceptor") + } + + streamInterceptors := []grpc.StreamServerInterceptor{authStream} + for _, t := range streamTriples { + streamInterceptors = append(streamInterceptors, t.Interceptor) + logger.Info().Msgf("rgrpc: chaining grpc streaming interceptor %s with priority %d", t.Name, t.Priority) + } + + streamInterceptors = append([]grpc.StreamServerInterceptor{ + authStream, + appctx.NewStream(*logger), + token.NewStream(), + useragent.NewStream(), + log.NewStream(), + recovery.NewStream(), + }, streamInterceptors...) + + return unaryInterceptors, streamInterceptors, nil +} + +func grpcUnprotected(s map[string]rgrpc.Service) (unprotected []string) { + for _, svc := range s { + unprotected = append(unprotected, svc.UnprotectedEndpoints()...) + } + return +} diff --git a/cmd/revad/runtime/http.go b/cmd/revad/runtime/http.go new file mode 100644 index 0000000000..7568c174ea --- /dev/null +++ b/cmd/revad/runtime/http.go @@ -0,0 +1,88 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package runtime + +import ( + "path" + "sort" + + "github.com/cs3org/reva/internal/http/interceptors/appctx" + "github.com/cs3org/reva/internal/http/interceptors/auth" + "github.com/cs3org/reva/internal/http/interceptors/log" + "github.com/cs3org/reva/pkg/rhttp/global" + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// middlewareTriple represents a middleware with the +// priority to be chained. +type middlewareTriple struct { + Name string + Priority int + Middleware global.Middleware +} + +func initHTTPMiddlewares(conf map[string]map[string]any, unprotected []string, logger *zerolog.Logger) ([]global.Middleware, error) { + triples := []*middlewareTriple{} + for name, c := range conf { + new, ok := global.NewMiddlewares[name] + if !ok { + continue + } + m, prio, err := new(c) + if err != nil { + return nil, errors.Wrapf(err, "error creating new middleware: %s,", name) + } + triples = append(triples, &middlewareTriple{ + Name: name, + Priority: prio, + Middleware: m, + }) + logger.Info().Msgf("http middleware enabled: %s", name) + } + + sort.SliceStable(triples, func(i, j int) bool { + return triples[i].Priority > triples[j].Priority + }) + + authMiddle, err := auth.New(conf["auth"], unprotected) + if err != nil { + return nil, errors.Wrap(err, "rhttp: error creating auth middleware") + } + + middlewares := []global.Middleware{ + authMiddle, + log.New(), + appctx.New(*logger), + } + + for _, triple := range triples { + middlewares = append(middlewares, triple.Middleware) + } + return middlewares, nil +} + +func httpUnprotected(s map[string]global.Service) (unprotected []string) { + for _, svc := range s { + for _, url := range svc.Unprotected() { + unprotected = append(unprotected, path.Join("/", svc.Prefix(), url)) + } + } + return +} diff --git a/cmd/revad/runtime/option.go b/cmd/revad/runtime/option.go index 6a662e9fa9..dd49966424 100644 --- a/cmd/revad/runtime/option.go +++ b/cmd/revad/runtime/option.go @@ -30,11 +30,15 @@ type Option func(o *Options) type Options struct { Logger *zerolog.Logger Registry registry.Registry + PidFile string } // newOptions initializes the available default options. func newOptions(opts ...Option) Options { - opt := Options{} + l := zerolog.Nop() + opt := Options{ + Logger: &l, + } for _, o := range opts { o(&opt) @@ -50,6 +54,13 @@ func WithLogger(logger *zerolog.Logger) Option { } } +// WithPidFile sets to pidfile to use. +func WithPidFile(pidfile string) Option { + return func(o *Options) { + o.PidFile = pidfile + } +} + // WithRegistry provides a function to set the registry. func WithRegistry(r registry.Registry) Option { return func(o *Options) { diff --git a/cmd/revad/runtime/runtime.go b/cmd/revad/runtime/runtime.go index 255c55bdf8..3a4b601109 100644 --- a/cmd/revad/runtime/runtime.go +++ b/cmd/revad/runtime/runtime.go @@ -20,284 +20,295 @@ package runtime import ( "fmt" - "io" - "log" "net" - "os" "runtime" "strconv" "strings" - "github.com/cs3org/reva/cmd/revad/internal/grace" - "github.com/cs3org/reva/pkg/logger" - "github.com/cs3org/reva/pkg/registry/memory" + "github.com/pkg/errors" + + "github.com/cs3org/reva/cmd/revad/pkg/config" + "github.com/cs3org/reva/cmd/revad/pkg/grace" "github.com/cs3org/reva/pkg/rgrpc" "github.com/cs3org/reva/pkg/rhttp" + "github.com/cs3org/reva/pkg/rhttp/global" "github.com/cs3org/reva/pkg/rserverless" "github.com/cs3org/reva/pkg/sharedconf" rtrace "github.com/cs3org/reva/pkg/trace" - "github.com/cs3org/reva/pkg/utils" - "github.com/mitchellh/mapstructure" - "github.com/pkg/errors" + "github.com/cs3org/reva/pkg/utils/list" + "github.com/cs3org/reva/pkg/utils/maps" + netutil "github.com/cs3org/reva/pkg/utils/net" "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" ) -// Run runs a reva server with the given config file and pid file. -func Run(mainConf map[string]interface{}, pidFile, logLevel string) { - logConf := parseLogConfOrDie(mainConf["log"], logLevel) - logger := initLogger(logConf) - RunWithOptions(mainConf, pidFile, WithLogger(logger)) -} +// Reva represents a full instance of reva. +type Reva struct { + config *config.Config -// RunWithOptions runs a reva server with the given config file, pid file and options. -func RunWithOptions(mainConf map[string]interface{}, pidFile string, opts ...Option) { - options := newOptions(opts...) - parseSharedConfOrDie(mainConf["shared"]) - coreConf := parseCoreConfOrDie(mainConf["core"]) - - // TODO: one can pass the options from the config file to registry.New() and initialize a registry based upon config files. - if options.Registry != nil { - utils.GlobalRegistry = options.Registry - } else if _, ok := mainConf["registry"]; ok { - for _, services := range mainConf["registry"].(map[string]interface{}) { - for sName, nodes := range services.(map[string]interface{}) { - for _, instance := range nodes.([]interface{}) { - if err := utils.GlobalRegistry.Add(memory.NewService(sName, instance.(map[string]interface{})["nodes"].([]interface{}))); err != nil { - panic(err) - } - } - } - } - } + servers []*Server + serverless *rserverless.Serverless + watcher *grace.Watcher + lns map[string]net.Listener - run(mainConf, coreConf, options.Logger, pidFile) + pidfile string + log *zerolog.Logger } -type coreConf struct { - MaxCPUs string `mapstructure:"max_cpus"` - TracingEnabled bool `mapstructure:"tracing_enabled"` - TracingEndpoint string `mapstructure:"tracing_endpoint"` - TracingCollector string `mapstructure:"tracing_collector"` - TracingServiceName string `mapstructure:"tracing_service_name"` +// Server represents a reva server (grpc or http). +type Server struct { + server grace.Server + listener net.Listener - // TracingService specifies the service. i.e OpenCensus, OpenTelemetry, OpenTracing... - TracingService string `mapstructure:"tracing_service"` + services map[string]any } -func run(mainConf map[string]interface{}, coreConf *coreConf, logger *zerolog.Logger, filename string) { - host, _ := os.Hostname() - logger.Info().Msgf("host info: %s", host) +// Start starts the server listening on the assigned listener. +func (s *Server) Start() error { + return s.server.Start(s.listener) +} - if coreConf.TracingEnabled { - initTracing(coreConf) - } - initCPUCount(coreConf, logger) +// New creates a new reva instance. +func New(config *config.Config, opt ...Option) (*Reva, error) { + opts := newOptions(opt...) + log := opts.Logger - servers := initServers(mainConf, logger) - serverless := initServerless(mainConf, logger) + if err := initCPUCount(config.Core, log); err != nil { + return nil, err + } + initTracing(config.Core) - if len(servers) == 0 && serverless == nil { - logger.Info().Msg("nothing to do, no grpc/http/serverless enabled_services declared in config") - os.Exit(1) + if opts.PidFile == "" { + return nil, errors.New("pid file not provided") } - watcher, err := initWatcher(logger, filename) + watcher, err := initWatcher(opts.PidFile, log) if err != nil { - log.Panic(err) + return nil, err } - listeners := initListeners(watcher, servers, logger) - if serverless != nil { - watcher.SL = serverless + + listeners, err := watcher.GetListeners(servicesAddresses(config)) + if err != nil { + watcher.Clean() + return nil, err } - start(mainConf, servers, serverless, listeners, logger, watcher) -} + setRandomAddresses(config, listeners, log) + + if err := applyTemplates(config); err != nil { + watcher.Clean() + return nil, err + } + initSharedConf(config) -func initListeners(watcher *grace.Watcher, servers map[string]grace.Server, log *zerolog.Logger) map[string]net.Listener { - listeners, err := watcher.GetListeners(servers) + grpc := groupGRPCByAddress(config) + http := groupHTTPByAddress(config) + servers, err := newServers(grpc, http, listeners, log) if err != nil { - log.Error().Err(err).Msg("error getting sockets") - watcher.Exit(1) + watcher.Clean() + return nil, err } - return listeners -} -func initWatcher(log *zerolog.Logger, filename string) (*grace.Watcher, error) { - watcher, err := handlePIDFlag(log, filename) - // TODO(labkode): maybe pidfile can be created later on? like once a server is going to be created? + serverless, err := newServerless(config, log) if err != nil { - log.Error().Err(err).Msg("error creating grace watcher") - os.Exit(1) + watcher.Clean() + return nil, err } - return watcher, err + + return &Reva{ + config: config, + servers: servers, + serverless: serverless, + watcher: watcher, + lns: listeners, + pidfile: opts.PidFile, + log: log, + }, nil } -func initServers(mainConf map[string]interface{}, log *zerolog.Logger) map[string]grace.Server { - servers := map[string]grace.Server{} - if isEnabledHTTP(mainConf) { - s, err := getHTTPServer(mainConf["http"], log) - if err != nil { - log.Error().Err(err).Msg("error creating http server") - os.Exit(1) - } - servers["http"] = s - } +func servicesAddresses(cfg *config.Config) map[string]grace.Addressable { + a := make(map[string]grace.Addressable) + cfg.GRPC.ForEachService(func(s *config.Service) { + a[s.Label] = &addr{address: s.Address.String(), network: s.Network} + }) + cfg.HTTP.ForEachService(func(s *config.Service) { + a[s.Label] = &addr{address: s.Address.String(), network: s.Network} + }) + return a +} - if isEnabledGRPC(mainConf) { - s, err := getGRPCServer(mainConf["grpc"], log) +func newServerless(config *config.Config, log *zerolog.Logger) (*rserverless.Serverless, error) { + sl := make(map[string]rserverless.Service) + logger := log.With().Str("pkg", "serverless").Logger() + if err := config.Serverless.ForEach(func(name string, config map[string]any) error { + new, ok := rserverless.Services[name] + if !ok { + return fmt.Errorf("serverless service %s does not exist", name) + } + log := logger.With().Str("service", name).Logger() + svc, err := new(config, &log) if err != nil { - log.Error().Err(err).Msg("error creating grpc server") - os.Exit(1) + return errors.Wrapf(err, "serverless service %s could not be initialized", name) } - servers["grpc"] = s + sl[name] = svc + return nil + }); err != nil { + return nil, err } - return servers + ss, err := rserverless.New( + rserverless.WithLogger(&logger), + rserverless.WithServices(sl), + ) + if err != nil { + return nil, err + } + return ss, nil } -func initServerless(mainConf map[string]interface{}, log *zerolog.Logger) *rserverless.Serverless { - if isEnabledServerless(mainConf) { - serverless, err := getServerless(mainConf["serverless"], log) - if err != nil { - log.Error().Err(err).Msg("error") - os.Exit(1) +func setRandomAddresses(c *config.Config, lns map[string]net.Listener, log *zerolog.Logger) { + f := func(s *config.Service) { + if s.Address != "" { + return } - return serverless + ln, ok := lns[s.Label] + if !ok { + log.Fatal().Msg("port not assigned for service " + s.Label) + } + s.SetAddress(config.Address(ln.Addr().String())) + log.Debug(). + Msgf("set random address %s:%s to service %s", ln.Addr().Network(), ln.Addr().String(), s.Label) } - - return nil + c.GRPC.ForEachService(f) + c.HTTP.ForEachService(f) } -func initTracing(conf *coreConf) { - rtrace.SetTraceProvider(conf.TracingCollector, conf.TracingEndpoint, conf.TracingServiceName) +type addr struct { + address string + network string } -func initCPUCount(conf *coreConf, log *zerolog.Logger) { - ncpus, err := adjustCPU(conf.MaxCPUs) - if err != nil { - log.Error().Err(err).Msg("error adjusting number of cpus") - os.Exit(1) - } - // log.Info().Msgf("%s", getVersionString()) - log.Info().Msgf("running on %d cpus", ncpus) +func (a *addr) Address() string { + return a.address } -func initLogger(conf *logConf) *zerolog.Logger { - log, err := newLogger(conf) - if err != nil { - fmt.Fprintf(os.Stderr, "error creating logger, exiting ...") - os.Exit(1) - } - return log +func (a *addr) Network() string { + return a.network } -func handlePIDFlag(l *zerolog.Logger, pidFile string) (*grace.Watcher, error) { - var opts []grace.Option - opts = append(opts, grace.WithPIDFile(pidFile)) - opts = append(opts, grace.WithLogger(l.With().Str("pkg", "grace").Logger())) - w := grace.NewWatcher(opts...) - err := w.WritePID() - if err != nil { - return nil, err +func groupGRPCByAddress(cfg *config.Config) []*config.GRPC { + // TODO: same address cannot be used in different configurations + g := map[string]*config.GRPC{} + cfg.GRPC.ForEachService(func(s *config.Service) { + if _, ok := g[s.Address.String()]; !ok { + g[s.Address.String()] = &config.GRPC{ + Address: s.Address, + Network: s.Network, + ShutdownDeadline: cfg.GRPC.ShutdownDeadline, + EnableReflection: cfg.GRPC.EnableReflection, + Services: make(map[string]config.ServicesConfig), + Interceptors: cfg.GRPC.Interceptors, + } + } + g[s.Address.String()].Services[s.Name] = config.ServicesConfig{ + {Config: s.Config, Address: s.Address, Network: s.Network, Label: s.Label}, + } + }) + l := make([]*config.GRPC, 0, len(g)) + for _, c := range g { + l = append(l, c) } - - return w, nil + return l } -func start(mainConf map[string]interface{}, servers map[string]grace.Server, serverless *rserverless.Serverless, listeners map[string]net.Listener, log *zerolog.Logger, watcher *grace.Watcher) { - if isEnabledHTTP(mainConf) { - go func() { - if err := servers["http"].(*rhttp.Server).Start(listeners["http"]); err != nil { - log.Error().Err(err).Msg("error starting the http server") - watcher.Exit(1) +func groupHTTPByAddress(cfg *config.Config) []*config.HTTP { + g := map[string]*config.HTTP{} + cfg.HTTP.ForEachService(func(s *config.Service) { + if _, ok := g[s.Address.String()]; !ok { + g[s.Address.String()] = &config.HTTP{ + Address: s.Address, + Network: s.Network, + CertFile: cfg.HTTP.CertFile, + KeyFile: cfg.HTTP.KeyFile, + Services: make(map[string]config.ServicesConfig), + Middlewares: cfg.HTTP.Middlewares, } - }() - } - if isEnabledGRPC(mainConf) { - go func() { - if err := servers["grpc"].(*rgrpc.Server).Start(listeners["grpc"]); err != nil { - log.Error().Err(err).Msg("error starting the grpc server") - watcher.Exit(1) - } - }() - } - if isEnabledServerless(mainConf) { - if err := serverless.Start(); err != nil { - log.Error().Err(err).Msg("error starting serverless services") - watcher.Exit(1) } + g[s.Address.String()].Services[s.Name] = config.ServicesConfig{ + {Config: s.Config, Address: s.Address, Network: s.Network, Label: s.Label}, + } + }) + l := make([]*config.HTTP, 0, len(g)) + for _, c := range g { + l = append(l, c) } - - watcher.TrapSignals() + return l } -func newLogger(conf *logConf) (*zerolog.Logger, error) { - // TODO(labkode): use debug level rather than info as default until reaching a stable version. - // Helps having smaller development files. - if conf.Level == "" { - conf.Level = zerolog.DebugLevel.String() - } - - var opts []logger.Option - opts = append(opts, logger.WithLevel(conf.Level)) +// Start starts all the reva services and waits for a signal. +func (r *Reva) Start() error { + defer r.watcher.Clean() + r.watcher.SetServers(list.Map(r.servers, func(s *Server) grace.Server { return s.server })) + r.watcher.SetServerless(r.serverless) - w, err := getWriter(conf.Output) - if err != nil { - return nil, err + var g errgroup.Group + for _, server := range r.servers { + server := server + g.Go(func() error { + return server.Start() + }) } - opts = append(opts, logger.WithWriter(w, logger.Mode(conf.Mode))) + g.Go(func() error { + return r.serverless.Start() + }) - l := logger.New(opts...) - sub := l.With().Int("pid", os.Getpid()).Logger() - return &sub, nil + r.watcher.TrapSignals() + return g.Wait() } -func getWriter(out string) (io.Writer, error) { - if out == "stderr" || out == "" { - return os.Stderr, nil - } - - if out == "stdout" { - return os.Stdout, nil - } +func initSharedConf(config *config.Config) { + sharedconf.Init(config.Shared) +} - fd, err := os.OpenFile(out, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - err = errors.Wrap(err, "error creating log file: "+out) - return nil, err - } +func initWatcher(filename string, log *zerolog.Logger) (*grace.Watcher, error) { + return handlePIDFlag(log, filename) + // TODO(labkode): maybe pidfile can be created later on? like once a server is going to be created? +} - return fd, nil +func applyTemplates(config *config.Config) error { + return config.ApplyTemplates(config) } -func getGRPCServer(conf interface{}, l *zerolog.Logger) (*rgrpc.Server, error) { - sub := l.With().Str("pkg", "rgrpc").Logger() - s, err := rgrpc.NewServer(conf, sub) +func initCPUCount(conf *config.Core, log *zerolog.Logger) error { + ncpus, err := adjustCPU(conf.MaxCPUs) if err != nil { - err = errors.Wrap(err, "main: error creating grpc server") - return nil, err + return errors.Wrap(err, "error adjusting number of cpus") } - return s, nil + log.Info().Msgf("running on %d cpus", ncpus) + return nil } -func getHTTPServer(conf interface{}, l *zerolog.Logger) (*rhttp.Server, error) { - sub := l.With().Str("pkg", "rhttp").Logger() - s, err := rhttp.New(conf, sub) +func handlePIDFlag(l *zerolog.Logger, pidFile string) (*grace.Watcher, error) { + w := grace.NewWatcher( + grace.WithPIDFile(pidFile), + grace.WithLogger(l.With().Str("pkg", "grace").Logger()), + ) + err := w.WritePID() if err != nil { - err = errors.Wrap(err, "main: error creating http server") return nil, err } - return s, nil + + return w, nil } -func getServerless(conf interface{}, l *zerolog.Logger) (*rserverless.Serverless, error) { - sub := l.With().Str("pkg", "rserverless").Logger() - return rserverless.New(conf, sub) +func initTracing(conf *config.Core) { + if conf.TracingEnabled { + rtrace.SetTraceProvider(conf.TracingCollector, conf.TracingEndpoint, conf.TracingServiceName) + } } -// adjustCPU parses string cpu and sets GOMAXPROCS -// +// adjustCPU parses string cpu and sets GOMAXPROCS // according to its value. It accepts either // a number (e.g. 3) or a percent (e.g. 50%). // Default is to use all available cores. @@ -337,81 +348,77 @@ func adjustCPU(cpu string) (int, error) { return numCPU, nil } -func parseCoreConfOrDie(v interface{}) *coreConf { - c := &coreConf{} - if err := mapstructure.Decode(v, c); err != nil { - fmt.Fprintf(os.Stderr, "error decoding core config: %s\n", err.Error()) - os.Exit(1) - } - - // tracing defaults to enabled if not explicitly configured - if v == nil { - c.TracingEnabled = true - c.TracingEndpoint = "localhost:6831" - } else if _, ok := v.(map[string]interface{})["tracing_enabled"]; !ok { - c.TracingEnabled = true - c.TracingEndpoint = "localhost:6831" - } - - return c -} - -func parseSharedConfOrDie(v interface{}) { - if err := sharedconf.Decode(v); err != nil { - fmt.Fprintf(os.Stderr, "error decoding shared config: %s\n", err.Error()) - os.Exit(1) - } -} - -func parseLogConfOrDie(v interface{}, logLevel string) *logConf { - c := &logConf{} - if err := mapstructure.Decode(v, c); err != nil { - fmt.Fprintf(os.Stderr, "error decoding log config: %s\n", err.Error()) - os.Exit(1) - } - - // if mode is not set, we use console mode, easier for devs - if c.Mode == "" { - c.Mode = "console" - } - - // Give priority to the log level passed through the command line. - if logLevel != "" { - c.Level = logLevel +func listenerFromAddress(lns map[string]net.Listener, network string, address config.Address) net.Listener { + for _, ln := range lns { + if netutil.AddressEqual(ln.Addr(), network, address.String()) { + return ln + } } - - return c -} - -type logConf struct { - Output string `mapstructure:"output"` - Mode string `mapstructure:"mode"` - Level string `mapstructure:"level"` -} - -func isEnabledHTTP(conf map[string]interface{}) bool { - return isEnabled("http", conf) -} - -func isEnabledGRPC(conf map[string]interface{}) bool { - return isEnabled("grpc", conf) -} - -func isEnabledServerless(conf map[string]interface{}) bool { - return isEnabled("serverless", conf) + panic(fmt.Sprintf("listener not found for address %s:%s", network, address)) } -func isEnabled(key string, conf map[string]interface{}) bool { - if a, ok := conf[key]; ok { - if b, ok := a.(map[string]interface{}); ok { - if c, ok := b["services"]; ok { - if d, ok := c.(map[string]interface{}); ok { - if len(d) > 0 { - return true - } - } - } +func newServers(grpc []*config.GRPC, http []*config.HTTP, lns map[string]net.Listener, log *zerolog.Logger) ([]*Server, error) { + servers := make([]*Server, 0, len(grpc)+len(http)) + for _, cfg := range grpc { + services, err := rgrpc.InitServices(cfg.Services) + if err != nil { + return nil, err + } + unaryChain, streamChain, err := initGRPCInterceptors(cfg.Interceptors, grpcUnprotected(services), log) + if err != nil { + return nil, err + } + s, err := rgrpc.NewServer( + rgrpc.EnableReflection(cfg.EnableReflection), + rgrpc.WithShutdownDeadline(cfg.ShutdownDeadline), + rgrpc.WithLogger(log.With().Str("pkg", "grpc").Logger()), + rgrpc.WithServices(services), + rgrpc.WithUnaryServerInterceptors(unaryChain), + rgrpc.WithStreamServerInterceptors(streamChain), + ) + if err != nil { + return nil, err + } + ln := listenerFromAddress(lns, cfg.Network, cfg.Address) + server := &Server{ + server: s, + listener: ln, + services: maps.MapValues(services, func(s rgrpc.Service) any { return s }), + } + log.Debug(). + Interface("services", maps.Keys(cfg.Services)). + Msgf("spawned grpc server for services listening at %s:%s", ln.Addr().Network(), ln.Addr().String()) + servers = append(servers, server) + } + for _, cfg := range http { + log := log.With().Str("pkg", "http").Logger() + services, err := rhttp.InitServices(cfg.Services, &log) + if err != nil { + return nil, err + } + middlewares, err := initHTTPMiddlewares(cfg.Middlewares, httpUnprotected(services), &log) + if err != nil { + return nil, err + } + s, err := rhttp.New( + rhttp.WithServices(services), + rhttp.WithLogger(log), + rhttp.WithCertAndKeyFiles(cfg.CertFile, cfg.KeyFile), + rhttp.WithMiddlewares(middlewares), + ) + if err != nil { + return nil, err + } + ln := listenerFromAddress(lns, cfg.Network, cfg.Address) + server := &Server{ + server: s, + listener: ln, + services: maps.MapValues(services, func(s global.Service) any { return s }), } + log.Debug(). + Interface("services", maps.Keys(cfg.Services)). + Msgf("spawned http server for services listening at %s:%s", ln.Addr().Network(), ln.Addr().String()) + servers = append(servers, server) } - return false + return servers, nil } diff --git a/go.mod b/go.mod index a1a77de9d9..fc47fc8e7c 100644 --- a/go.mod +++ b/go.mod @@ -77,6 +77,7 @@ require ( ) require ( + github.com/creasty/defaults v1.7.0 // indirect github.com/go-jose/go-jose/v3 v3.0.0 // indirect github.com/hashicorp/go-msgpack/v2 v2.1.0 // indirect ) diff --git a/go.sum b/go.sum index 213016d950..5400bd80c6 100644 --- a/go.sum +++ b/go.sum @@ -304,6 +304,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsr github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creasty/defaults v1.7.0 h1:eNdqZvc5B509z18lD8yc212CAqJNvfT1Jq6L8WowdBA= +github.com/creasty/defaults v1.7.0/go.mod h1:iGzKe6pbEHnpMPtfDXZEr0NVxWnPTjb1bbDy08fPzYM= github.com/cs3org/cato v0.0.0-20200828125504-e418fc54dd5e h1:tqSPWQeueWTKnJVMJffz4pz0o1WuQxJ28+5x5JgaHD8= github.com/cs3org/cato v0.0.0-20200828125504-e418fc54dd5e/go.mod h1:XJEZ3/EQuI3BXTp/6DUzFr850vlxq11I6satRtz0YQ4= github.com/cs3org/go-cs3apis v0.0.0-20230727093620-0f4399be4543 h1:IFo6dj0XEOIA6i2baRWMC3vd+fAmuIUAVfSf77ZhoQg= diff --git a/internal/grpc/services/applicationauth/applicationauth.go b/internal/grpc/services/applicationauth/applicationauth.go index 65514b6dcd..47f4c24680 100644 --- a/internal/grpc/services/applicationauth/applicationauth.go +++ b/internal/grpc/services/applicationauth/applicationauth.go @@ -73,7 +73,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a app auth provider svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/appprovider/appprovider.go b/internal/grpc/services/appprovider/appprovider.go index 2a743cc5b4..be1365885e 100644 --- a/internal/grpc/services/appprovider/appprovider.go +++ b/internal/grpc/services/appprovider/appprovider.go @@ -82,7 +82,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new AppProviderService. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/appregistry/appregistry.go b/internal/grpc/services/appregistry/appregistry.go index 619d366942..d70b9e7746 100644 --- a/internal/grpc/services/appregistry/appregistry.go +++ b/internal/grpc/services/appregistry/appregistry.go @@ -63,7 +63,7 @@ func (c *config) init() { } // New creates a new StorageRegistryService. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/appregistry/appregistry_test.go b/internal/grpc/services/appregistry/appregistry_test.go index 0bf23f8fa3..34b8c5d3d1 100644 --- a/internal/grpc/services/appregistry/appregistry_test.go +++ b/internal/grpc/services/appregistry/appregistry_test.go @@ -355,7 +355,7 @@ func TestNew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := New(tt.m, nil) + got, err := New(tt.m) if err != nil { assert.Equal(t, tt.wantErr, err.Error()) assert.Nil(t, got) diff --git a/internal/grpc/services/authprovider/authprovider.go b/internal/grpc/services/authprovider/authprovider.go index c080d8135f..135af03745 100644 --- a/internal/grpc/services/authprovider/authprovider.go +++ b/internal/grpc/services/authprovider/authprovider.go @@ -100,7 +100,7 @@ func getAuthManager(manager string, m map[string]map[string]interface{}) (auth.M } // New returns a new AuthProviderServiceServer. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/authregistry/authregistry.go b/internal/grpc/services/authregistry/authregistry.go index 0abbe63ad5..dcbbd1b471 100644 --- a/internal/grpc/services/authregistry/authregistry.go +++ b/internal/grpc/services/authregistry/authregistry.go @@ -66,7 +66,7 @@ func (c *config) init() { } // New creates a new AuthRegistry. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/datatx/datatx.go b/internal/grpc/services/datatx/datatx.go index b9f5011a3d..ce82fa5c36 100644 --- a/internal/grpc/services/datatx/datatx.go +++ b/internal/grpc/services/datatx/datatx.go @@ -88,7 +88,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new datatx svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/gateway/gateway.go b/internal/grpc/services/gateway/gateway.go index 72776e5fb5..445ae7b13a 100644 --- a/internal/grpc/services/gateway/gateway.go +++ b/internal/grpc/services/gateway/gateway.go @@ -124,7 +124,7 @@ type svc struct { // New creates a new gateway svc that acts as a proxy for any grpc operation. // The gateway is responsible for high-level controls: rate-limiting, coordination between svcs // like sharing and storage acls, asynchronous transactions, ... -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/groupprovider/groupprovider.go b/internal/grpc/services/groupprovider/groupprovider.go index 1294e3892e..7e23c2f97b 100644 --- a/internal/grpc/services/groupprovider/groupprovider.go +++ b/internal/grpc/services/groupprovider/groupprovider.go @@ -68,7 +68,7 @@ func getDriver(c *config) (group.Manager, error) { } // New returns a new GroupProviderServiceServer. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/helloworld/helloworld.go b/internal/grpc/services/helloworld/helloworld.go index f0e9c311be..40aecaa356 100644 --- a/internal/grpc/services/helloworld/helloworld.go +++ b/internal/grpc/services/helloworld/helloworld.go @@ -43,7 +43,7 @@ type service struct { // New returns a new PreferencesServiceServer // It can be tested like this: // prototool grpc --address 0.0.0.0:9999 --method 'revad.helloworld.HelloWorldService/Hello' --data '{"name": "Alice"}'. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c := &conf{} if err := mapstructure.Decode(m, c); err != nil { err = errors.Wrap(err, "helloworld: error decoding conf") diff --git a/internal/grpc/services/ocmcore/ocmcore.go b/internal/grpc/services/ocmcore/ocmcore.go index ebb0d9b9f6..16686bc9fd 100644 --- a/internal/grpc/services/ocmcore/ocmcore.go +++ b/internal/grpc/services/ocmcore/ocmcore.go @@ -78,7 +78,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new ocm core svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/ocminvitemanager/ocminvitemanager.go b/internal/grpc/services/ocminvitemanager/ocminvitemanager.go index 80d8ff1d01..38cf922c5a 100644 --- a/internal/grpc/services/ocminvitemanager/ocminvitemanager.go +++ b/internal/grpc/services/ocminvitemanager/ocminvitemanager.go @@ -103,7 +103,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new OCM invite manager svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/ocmproviderauthorizer/ocmproviderauthorizer.go b/internal/grpc/services/ocmproviderauthorizer/ocmproviderauthorizer.go index 9d0892068d..e14645fe5a 100644 --- a/internal/grpc/services/ocmproviderauthorizer/ocmproviderauthorizer.go +++ b/internal/grpc/services/ocmproviderauthorizer/ocmproviderauthorizer.go @@ -74,7 +74,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new OCM provider authorizer svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/ocmshareprovider/ocmshareprovider.go b/internal/grpc/services/ocmshareprovider/ocmshareprovider.go index 9051f1dba2..c8f39962f0 100644 --- a/internal/grpc/services/ocmshareprovider/ocmshareprovider.go +++ b/internal/grpc/services/ocmshareprovider/ocmshareprovider.go @@ -110,7 +110,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new ocm share provider svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/permissions/permissions.go b/internal/grpc/services/permissions/permissions.go index 490d4c78fa..428e5457f1 100644 --- a/internal/grpc/services/permissions/permissions.go +++ b/internal/grpc/services/permissions/permissions.go @@ -55,7 +55,7 @@ type service struct { } // New returns a new PermissionsServiceServer. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/preferences/preferences.go b/internal/grpc/services/preferences/preferences.go index 25c814b1cc..f2b096cbf7 100644 --- a/internal/grpc/services/preferences/preferences.go +++ b/internal/grpc/services/preferences/preferences.go @@ -69,7 +69,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New returns a new PreferencesServiceServer. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/publicshareprovider/publicshareprovider.go b/internal/grpc/services/publicshareprovider/publicshareprovider.go index 5829b050d8..1d513e5554 100644 --- a/internal/grpc/services/publicshareprovider/publicshareprovider.go +++ b/internal/grpc/services/publicshareprovider/publicshareprovider.go @@ -87,7 +87,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new user share provider svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/publicstorageprovider/publicstorageprovider.go b/internal/grpc/services/publicstorageprovider/publicstorageprovider.go index bf913e4136..43c141e6bf 100644 --- a/internal/grpc/services/publicstorageprovider/publicstorageprovider.go +++ b/internal/grpc/services/publicstorageprovider/publicstorageprovider.go @@ -83,7 +83,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new IsPublic Storage Provider service. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/storageprovider/storageprovider.go b/internal/grpc/services/storageprovider/storageprovider.go index 572cf2bdd0..9bb1011817 100644 --- a/internal/grpc/services/storageprovider/storageprovider.go +++ b/internal/grpc/services/storageprovider/storageprovider.go @@ -162,7 +162,7 @@ func registerMimeTypes(mappingFile string) error { } // New creates a new storage provider svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/storageregistry/storageregistry.go b/internal/grpc/services/storageregistry/storageregistry.go index b7304ba571..685a8c0004 100644 --- a/internal/grpc/services/storageregistry/storageregistry.go +++ b/internal/grpc/services/storageregistry/storageregistry.go @@ -64,7 +64,7 @@ func (c *config) init() { } // New creates a new StorageBrokerService. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/userprovider/userprovider.go b/internal/grpc/services/userprovider/userprovider.go index 27907db225..fa6cb78abc 100644 --- a/internal/grpc/services/userprovider/userprovider.go +++ b/internal/grpc/services/userprovider/userprovider.go @@ -87,7 +87,7 @@ func getDriver(c *config) (user.Manager, *plugin.RevaPlugin, error) { } // New returns a new UserProviderServiceServer. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/internal/grpc/services/usershareprovider/usershareprovider.go b/internal/grpc/services/usershareprovider/usershareprovider.go index 523a7b6c1f..074397364e 100644 --- a/internal/grpc/services/usershareprovider/usershareprovider.go +++ b/internal/grpc/services/usershareprovider/usershareprovider.go @@ -89,7 +89,7 @@ func parseConfig(m map[string]interface{}) (*config, error) { } // New creates a new user share provider svc. -func New(m map[string]interface{}, ss *grpc.Server) (rgrpc.Service, error) { +func New(m map[string]interface{}) (rgrpc.Service, error) { c, err := parseConfig(m) if err != nil { return nil, err diff --git a/pkg/rgrpc/option.go b/pkg/rgrpc/option.go new file mode 100644 index 0000000000..b59492dda0 --- /dev/null +++ b/pkg/rgrpc/option.go @@ -0,0 +1,62 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package rgrpc + +import ( + "github.com/rs/zerolog" + "google.golang.org/grpc" +) + +type Option func(*Server) + +func WithShutdownDeadline(deadline int) Option { + return func(s *Server) { + s.ShutdownDeadline = deadline + } +} + +func EnableReflection(enable bool) Option { + return func(s *Server) { + s.EnableReflection = enable + } +} + +func WithServices(services map[string]Service) Option { + return func(s *Server) { + s.services = services + } +} + +func WithLogger(logger zerolog.Logger) Option { + return func(s *Server) { + s.log = logger + } +} + +func WithStreamServerInterceptors(in []grpc.StreamServerInterceptor) Option { + return func(s *Server) { + s.StreamServerInterceptors = in + } +} + +func WithUnaryServerInterceptors(in []grpc.UnaryServerInterceptor) Option { + return func(s *Server) { + s.UnaryServerInterceptors = in + } +} diff --git a/pkg/rgrpc/rgrpc.go b/pkg/rgrpc/rgrpc.go index bb363b276e..5508bcb36a 100644 --- a/pkg/rgrpc/rgrpc.go +++ b/pkg/rgrpc/rgrpc.go @@ -22,21 +22,11 @@ import ( "fmt" "io" "net" - "sort" - - "github.com/cs3org/reva/internal/grpc/interceptors/appctx" - "github.com/cs3org/reva/internal/grpc/interceptors/auth" - "github.com/cs3org/reva/internal/grpc/interceptors/log" - "github.com/cs3org/reva/internal/grpc/interceptors/recovery" - "github.com/cs3org/reva/internal/grpc/interceptors/token" - "github.com/cs3org/reva/internal/grpc/interceptors/useragent" - "github.com/cs3org/reva/pkg/sharedconf" - rtrace "github.com/cs3org/reva/pkg/trace" + + "github.com/cs3org/reva/cmd/revad/pkg/config" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/mitchellh/mapstructure" "github.com/pkg/errors" "github.com/rs/zerolog" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" "google.golang.org/grpc/reflection" ) @@ -73,7 +63,7 @@ func Register(name string, newFunc NewService) { // NewService is the function that gRPC services need to register at init time. // It returns an io.Closer to close the service and a list of service endpoints that need to be unprotected. -type NewService func(conf map[string]interface{}, ss *grpc.Server) (Service, error) +type NewService func(conf map[string]interface{}) (Service, error) // Service represents a grpc service. type Service interface { @@ -82,63 +72,51 @@ type Service interface { UnprotectedEndpoints() []string } -type unaryInterceptorTriple struct { - Name string - Priority int - Interceptor grpc.UnaryServerInterceptor -} - -type streamInterceptorTriple struct { - Name string - Priority int - Interceptor grpc.StreamServerInterceptor -} - -type config struct { - Network string `mapstructure:"network"` - Address string `mapstructure:"address"` - ShutdownDeadline int `mapstructure:"shutdown_deadline"` - Services map[string]map[string]interface{} `mapstructure:"services"` - Interceptors map[string]map[string]interface{} `mapstructure:"interceptors"` - EnableReflection bool `mapstructure:"enable_reflection"` -} - -func (c *config) init() { - if c.Network == "" { - c.Network = "tcp" - } - - if c.Address == "" { - c.Address = sharedconf.GetGatewaySVC("0.0.0.0:19000") - } -} - // Server is a gRPC server. type Server struct { + ShutdownDeadline int + EnableReflection bool + UnaryServerInterceptors []grpc.UnaryServerInterceptor + StreamServerInterceptors []grpc.StreamServerInterceptor + s *grpc.Server - conf *config listener net.Listener log zerolog.Logger services map[string]Service } -// NewServer returns a new Server. -func NewServer(m interface{}, log zerolog.Logger) (*Server, error) { - conf := &config{} - if err := mapstructure.Decode(m, conf); err != nil { - return nil, err +func InitServices(services map[string]config.ServicesConfig) (map[string]Service, error) { + s := make(map[string]Service) + for name, cfg := range services { + new, ok := Services[name] + if !ok { + return nil, fmt.Errorf("rgrpc: grpc service %s does not exist", name) + } + if cfg.DriversNumber() > 1 { + return nil, fmt.Errorf("rgrp: service %s cannot have more than one driver in same server", name) + } + svc, err := new(cfg[0].Config) + if err != nil { + return nil, errors.Wrapf(err, "rgrpc: grpc service %s could not be started,", name) + } + s[name] = svc } + return s, nil +} - conf.init() - - server := &Server{conf: conf, log: log, services: map[string]Service{}} +// NewServer returns a new Server. +func NewServer(o ...Option) (*Server, error) { + server := &Server{} + for _, oo := range o { + oo(server) + } return server, nil } // Start starts the server. func (s *Server) Start(ln net.Listener) error { - if err := s.registerServices(); err != nil { + if err := s.initServices(); err != nil { err = errors.Wrap(err, "unable to register services") return err } @@ -153,57 +131,15 @@ func (s *Server) Start(ln net.Listener) error { return nil } -func (s *Server) isInterceptorEnabled(name string) bool { - for k := range s.conf.Interceptors { - if k == name { - return true - } - } - return false -} - -func (s *Server) isServiceEnabled(svcName string) bool { - for key := range Services { - if key == svcName { - return true - } - } - return false -} - -func (s *Server) registerServices() error { - for svcName := range s.conf.Services { - if s.isServiceEnabled(svcName) { - newFunc := Services[svcName] - svc, err := newFunc(s.conf.Services[svcName], s.s) - if err != nil { - return errors.Wrapf(err, "rgrpc: grpc service %s could not be started,", svcName) - } - s.services[svcName] = svc - s.log.Info().Msgf("rgrpc: grpc service enabled: %s", svcName) - } else { - message := fmt.Sprintf("rgrpc: grpc service %s does not exist", svcName) - return errors.New(message) - } - } - - // obtain list of unprotected endpoints - unprotected := []string{} - for _, svc := range s.services { - unprotected = append(unprotected, svc.UnprotectedEndpoints()...) - } - - opts, err := s.getInterceptors(unprotected) - if err != nil { - return err - } +func (s *Server) initServices() error { + opts := s.getInterceptors() grpcServer := grpc.NewServer(opts...) for _, svc := range s.services { svc.Register(grpcServer) } - if s.conf.EnableReflection { + if s.EnableReflection { s.log.Info().Msg("rgrpc: grpc server reflection enabled") reflection.Register(grpcServer) } @@ -240,109 +176,20 @@ func (s *Server) GracefulStop() error { // Network returns the network type. func (s *Server) Network() string { - return s.conf.Network + return s.listener.Addr().Network() } // Address returns the network address. func (s *Server) Address() string { - return s.conf.Address + return s.listener.Addr().String() } -func (s *Server) getInterceptors(unprotected []string) ([]grpc.ServerOption, error) { - unaryTriples := []*unaryInterceptorTriple{} - for name, newFunc := range UnaryInterceptors { - if s.isInterceptorEnabled(name) { - inter, prio, err := newFunc(s.conf.Interceptors[name]) - if err != nil { - err = errors.Wrapf(err, "rgrpc: error creating unary interceptor: %s,", name) - return nil, err - } - triple := &unaryInterceptorTriple{ - Name: name, - Priority: prio, - Interceptor: inter, - } - unaryTriples = append(unaryTriples, triple) - } - } - - // sort unary triples - sort.SliceStable(unaryTriples, func(i, j int) bool { - return unaryTriples[i].Priority < unaryTriples[j].Priority - }) +func (s *Server) getInterceptors() []grpc.ServerOption { + unaryChain := grpc_middleware.ChainUnaryServer(s.UnaryServerInterceptors...) + streamChain := grpc_middleware.ChainStreamServer(s.StreamServerInterceptors...) - authUnary, err := auth.NewUnary(s.conf.Interceptors["auth"], unprotected) - if err != nil { - return nil, errors.Wrap(err, "rgrpc: error creating unary auth interceptor") - } - - unaryInterceptors := []grpc.UnaryServerInterceptor{authUnary} - for _, t := range unaryTriples { - unaryInterceptors = append(unaryInterceptors, t.Interceptor) - s.log.Info().Msgf("rgrpc: chaining grpc unary interceptor %s with priority %d", t.Name, t.Priority) - } - - unaryInterceptors = append(unaryInterceptors, - otelgrpc.UnaryServerInterceptor( - otelgrpc.WithTracerProvider(rtrace.Provider), - otelgrpc.WithPropagators(rtrace.Propagator)), - ) - - unaryInterceptors = append([]grpc.UnaryServerInterceptor{ - appctx.NewUnary(s.log), - token.NewUnary(), - useragent.NewUnary(), - log.NewUnary(), - recovery.NewUnary(), - }, unaryInterceptors...) - unaryChain := grpc_middleware.ChainUnaryServer(unaryInterceptors...) - - streamTriples := []*streamInterceptorTriple{} - for name, newFunc := range StreamInterceptors { - if s.isInterceptorEnabled(name) { - inter, prio, err := newFunc(s.conf.Interceptors[name]) - if err != nil { - err = errors.Wrapf(err, "rgrpc: error creating streaming interceptor: %s,", name) - return nil, err - } - triple := &streamInterceptorTriple{ - Name: name, - Priority: prio, - Interceptor: inter, - } - streamTriples = append(streamTriples, triple) - } - } - // sort stream triples - sort.SliceStable(streamTriples, func(i, j int) bool { - return streamTriples[i].Priority < streamTriples[j].Priority - }) - - authStream, err := auth.NewStream(s.conf.Interceptors["auth"], unprotected) - if err != nil { - return nil, errors.Wrap(err, "rgrpc: error creating stream auth interceptor") - } - - streamInterceptors := []grpc.StreamServerInterceptor{authStream} - for _, t := range streamTriples { - streamInterceptors = append(streamInterceptors, t.Interceptor) - s.log.Info().Msgf("rgrpc: chaining grpc streaming interceptor %s with priority %d", t.Name, t.Priority) - } - - streamInterceptors = append([]grpc.StreamServerInterceptor{ - authStream, - appctx.NewStream(s.log), - token.NewStream(), - useragent.NewStream(), - log.NewStream(), - recovery.NewStream(), - }, streamInterceptors...) - streamChain := grpc_middleware.ChainStreamServer(streamInterceptors...) - - opts := []grpc.ServerOption{ + return []grpc.ServerOption{ grpc.UnaryInterceptor(unaryChain), grpc.StreamInterceptor(streamChain), } - - return opts, nil } diff --git a/pkg/rhttp/rhttp.go b/pkg/rhttp/rhttp.go index de000aaac1..e5d796f917 100644 --- a/pkg/rhttp/rhttp.go +++ b/pkg/rhttp/rhttp.go @@ -24,85 +24,99 @@ import ( "net" "net/http" "path" - "sort" "strings" "time" - "github.com/cs3org/reva/internal/http/interceptors/appctx" - "github.com/cs3org/reva/internal/http/interceptors/auth" - "github.com/cs3org/reva/internal/http/interceptors/log" - "github.com/cs3org/reva/internal/http/interceptors/providerauthorizer" + "github.com/cs3org/reva/cmd/revad/pkg/config" "github.com/cs3org/reva/pkg/rhttp/global" rtrace "github.com/cs3org/reva/pkg/trace" - "github.com/mitchellh/mapstructure" "github.com/pkg/errors" "github.com/rs/zerolog" "go.opentelemetry.io/otel/propagation" ) -// New returns a new server. -func New(m interface{}, l zerolog.Logger) (*Server, error) { - conf := &config{} - if err := mapstructure.Decode(m, conf); err != nil { - return nil, err +type Config func(*Server) + +func WithServices(services map[string]global.Service) Config { + return func(s *Server) { + s.Services = services + } +} + +func WithMiddlewares(middlewares []global.Middleware) Config { + return func(s *Server) { + s.middlewares = middlewares + } +} + +func WithCertAndKeyFiles(cert, key string) Config { + return func(s *Server) { + s.CertFile = cert + s.KeyFile = key } +} - conf.init() +func WithLogger(log zerolog.Logger) Config { + return func(s *Server) { + s.log = log + } +} +func InitServices(services map[string]config.ServicesConfig, log *zerolog.Logger) (map[string]global.Service, error) { + s := make(map[string]global.Service) + for name, cfg := range services { + new, ok := global.Services[name] + if !ok { + return nil, fmt.Errorf("http service %s does not exist", name) + } + if cfg.DriversNumber() > 1 { + return nil, fmt.Errorf("service %s cannot have more than one driver in the same server", name) + } + log := log.With().Str("service", name).Logger() + svc, err := new(cfg[0].Config, &log) + if err != nil { + return nil, errors.Wrapf(err, "http service %s could not be started", name) + } + s[name] = svc + } + return s, nil +} + +// New returns a new server. +func New(c ...Config) (*Server, error) { httpServer := &http.Server{} s := &Server{ + log: zerolog.Nop(), httpServer: httpServer, - conf: conf, svcs: map[string]global.Service{}, unprotected: []string{}, handlers: map[string]http.Handler{}, - log: l, + middlewares: []global.Middleware{}, + } + for _, cc := range c { + cc(s) } + s.registerServices() return s, nil } // Server contains the server info. type Server struct { + Services map[string]global.Service // map key is service name + CertFile string + KeyFile string + httpServer *http.Server - conf *config listener net.Listener svcs map[string]global.Service // map key is svc Prefix unprotected []string handlers map[string]http.Handler - middlewares []*middlewareTriple + middlewares []global.Middleware log zerolog.Logger } -type config struct { - Network string `mapstructure:"network"` - Address string `mapstructure:"address"` - Services map[string]map[string]interface{} `mapstructure:"services"` - Middlewares map[string]map[string]interface{} `mapstructure:"middlewares"` - CertFile string `mapstructure:"certfile"` - KeyFile string `mapstructure:"keyfile"` -} - -func (c *config) init() { - // apply defaults - if c.Network == "" { - c.Network = "tcp" - } - - if c.Address == "" { - c.Address = "0.0.0.0:19001" - } -} - // Start starts the server. func (s *Server) Start(ln net.Listener) error { - if err := s.registerServices(); err != nil { - return err - } - - if err := s.registerMiddlewares(); err != nil { - return err - } - handler, err := s.getHandler() if err != nil { return errors.Wrap(err, "rhttp: error creating http handler") @@ -111,11 +125,11 @@ func (s *Server) Start(ln net.Listener) error { s.httpServer.Handler = handler s.listener = ln - if (s.conf.CertFile != "") && (s.conf.KeyFile != "") { - s.log.Info().Msgf("https server listening at https://%s '%s' '%s'", s.conf.Address, s.conf.CertFile, s.conf.KeyFile) - err = s.httpServer.ServeTLS(s.listener, s.conf.CertFile, s.conf.KeyFile) + if (s.CertFile != "") && (s.KeyFile != "") { + s.log.Info().Msgf("https server listening at https://%s using cert file '%s' and key file '%s'", s.listener.Addr(), s.CertFile, s.KeyFile) + err = s.httpServer.ServeTLS(s.listener, s.CertFile, s.KeyFile) } else { - s.log.Info().Msgf("http server listening at http://%s '%s' '%s'", s.conf.Address, s.conf.CertFile, s.conf.KeyFile) + s.log.Info().Msgf("http server listening at http://%s", s.listener.Addr()) err = s.httpServer.Serve(s.listener) } if err == nil || err == http.ErrServerClosed { @@ -148,12 +162,12 @@ func (s *Server) closeServices() { // Network return the network type. func (s *Server) Network() string { - return s.conf.Network + return s.listener.Addr().Network() } // Address returns the network address. func (s *Server) Address() string { - return s.conf.Address + return s.listener.Addr().String() } // GracefulStop gracefully stops the server. @@ -162,68 +176,15 @@ func (s *Server) GracefulStop() error { return s.httpServer.Shutdown(context.Background()) } -// middlewareTriple represents a middleware with the -// priority to be chained. -type middlewareTriple struct { - Name string - Priority int - Middleware global.Middleware -} - -func (s *Server) registerMiddlewares() error { - middlewares := []*middlewareTriple{} - for name, newFunc := range global.NewMiddlewares { - if s.isMiddlewareEnabled(name) { - m, prio, err := newFunc(s.conf.Middlewares[name]) - if err != nil { - err = errors.Wrapf(err, "error creating new middleware: %s,", name) - return err - } - middlewares = append(middlewares, &middlewareTriple{ - Name: name, - Priority: prio, - Middleware: m, - }) - s.log.Info().Msgf("http middleware enabled: %s", name) - } - } - s.middlewares = middlewares - return nil -} - -func (s *Server) isMiddlewareEnabled(name string) bool { - _, ok := s.conf.Middlewares[name] - return ok -} - -func (s *Server) registerServices() error { - for svcName := range s.conf.Services { - if s.isServiceEnabled(svcName) { - newFunc := global.Services[svcName] - svcLogger := s.log.With().Str("service", svcName).Logger() - svc, err := newFunc(s.conf.Services[svcName], &svcLogger) - if err != nil { - err = errors.Wrapf(err, "http service %s could not be started,", svcName) - return err - } - - // instrument services with opencensus tracing. - h := traceHandler(svcName, svc.Handler()) - s.handlers[svc.Prefix()] = h - s.svcs[svc.Prefix()] = svc - s.unprotected = append(s.unprotected, getUnprotected(svc.Prefix(), svc.Unprotected())...) - s.log.Info().Msgf("http service enabled: %s@/%s", svcName, svc.Prefix()) - } else { - message := fmt.Sprintf("http service %s does not exist", svcName) - return errors.New(message) - } +func (s *Server) registerServices() { + for name, svc := range s.Services { + // instrument services with opencensus tracing. + h := traceHandler(name, svc.Handler()) + s.handlers[svc.Prefix()] = h + s.svcs[svc.Prefix()] = svc + s.unprotected = append(s.unprotected, getUnprotected(svc.Prefix(), svc.Unprotected())...) + s.log.Info().Msgf("http service enabled: %s@/%s", name, svc.Prefix()) } - return nil -} - -func (s *Server) isServiceEnabled(svcName string) bool { - _, ok := global.Services[svcName] - return ok } // TODO(labkode): if the http server is exposed under a basename we need to prepend @@ -312,44 +273,9 @@ func (s *Server) getHandler() (http.Handler, error) { w.WriteHeader(http.StatusNotFound) }) - // sort middlewares by priority. - sort.SliceStable(s.middlewares, func(i, j int) bool { - return s.middlewares[i].Priority > s.middlewares[j].Priority - }) - handler := http.Handler(h) - - for _, triple := range s.middlewares { - s.log.Info().Msgf("chaining http middleware %s with priority %d", triple.Name, triple.Priority) - handler = triple.Middleware(traceHandler(triple.Name, handler)) - } - - for _, v := range s.unprotected { - s.log.Info().Msgf("unprotected URL: %s", v) - } - authMiddle, err := auth.New(s.conf.Middlewares["auth"], s.unprotected) - if err != nil { - return nil, errors.Wrap(err, "rhttp: error creating auth middleware") - } - - // add always the logctx middleware as most priority, this middleware is internal - // and cannot be configured from the configuration. - coreMiddlewares := []*middlewareTriple{} - - providerAuthMiddle, err := addProviderAuthMiddleware(s.conf, s.unprotected) - if err != nil { - return nil, errors.Wrap(err, "rhttp: error creating providerauthorizer middleware") - } - if providerAuthMiddle != nil { - coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: providerAuthMiddle, Name: "providerauthorizer"}) - } - - coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: authMiddle, Name: "auth"}) - coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: log.New(), Name: "log"}) - coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: appctx.New(s.log), Name: "appctx"}) - - for _, triple := range coreMiddlewares { - handler = triple.Middleware(traceHandler(triple.Name, handler)) + for _, m := range s.middlewares { + handler = m(handler) } return handler, nil @@ -366,13 +292,3 @@ func traceHandler(name string, h http.Handler) http.Handler { h.ServeHTTP(w, r.WithContext(ctx)) }) } - -func addProviderAuthMiddleware(conf *config, unprotected []string) (global.Middleware, error) { - _, ocmdRegistered := global.Services["ocmd"] - _, ocmdEnabled := conf.Services["ocmd"] - ocmdPrefix, _ := conf.Services["ocmd"]["prefix"].(string) - if ocmdRegistered && ocmdEnabled { - return providerauthorizer.New(conf.Middlewares["providerauthorizer"], unprotected, ocmdPrefix) - } - return nil, nil -} diff --git a/cmd/revad/internal/config/config.go b/pkg/rserverless/option.go similarity index 61% rename from cmd/revad/internal/config/config.go rename to pkg/rserverless/option.go index 9ece6f8c1e..937b1eaa40 100644 --- a/cmd/revad/internal/config/config.go +++ b/pkg/rserverless/option.go @@ -16,29 +16,20 @@ // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. -package config +package rserverless -import ( - "io" +import "github.com/rs/zerolog" - "github.com/BurntSushi/toml" - "github.com/pkg/errors" -) +type Option func(*Serverless) -// Read reads the configuration from the reader. -func Read(r io.Reader) (map[string]interface{}, error) { - data, err := io.ReadAll(r) - if err != nil { - err = errors.Wrap(err, "config: error reading from reader") - return nil, err +func WithLogger(log *zerolog.Logger) Option { + return func(s *Serverless) { + s.log = log } +} - v := map[string]interface{}{} - err = toml.Unmarshal(data, &v) - if err != nil { - err = errors.Wrap(err, "config: error decoding toml data") - return nil, err +func WithServices(svc map[string]Service) Option { + return func(s *Serverless) { + s.Services = svc } - - return v, nil } diff --git a/pkg/rserverless/rserverless.go b/pkg/rserverless/rserverless.go index 7af26640c5..c68f1d6142 100644 --- a/pkg/rserverless/rserverless.go +++ b/pkg/rserverless/rserverless.go @@ -20,12 +20,9 @@ package rserverless import ( "context" - "fmt" "sync" "time" - "github.com/mitchellh/mapstructure" - "github.com/pkg/errors" "github.com/rs/zerolog" ) @@ -48,45 +45,37 @@ type NewService func(conf map[string]interface{}, log *zerolog.Logger) (Service, // Serverless contains the serveless collection of services. type Serverless struct { - conf *config - log zerolog.Logger - services map[string]Service -} - -type config struct { - Services map[string]map[string]interface{} `mapstructure:"services"` + log *zerolog.Logger + Services map[string]Service } // New returns a new serverless collection of services. -func New(m interface{}, l zerolog.Logger) (*Serverless, error) { - conf := &config{} - if err := mapstructure.Decode(m, conf); err != nil { - return nil, err - } - +func New(opt ...Option) (*Serverless, error) { + l := zerolog.Nop() n := &Serverless{ - conf: conf, - log: l, - services: map[string]Service{}, + Services: map[string]Service{}, + log: &l, + } + for _, o := range opt { + o(n) } return n, nil } -func (s *Serverless) isServiceEnabled(svcName string) bool { - _, ok := Services[svcName] - return ok -} - // Start starts the serverless service collection. func (s *Serverless) Start() error { - return s.registerAndStartServices() + for name, svc := range s.Services { + go svc.Start() + s.log.Info().Msgf("serverless service enabled: %s", name) + } + return nil } // GracefulStop gracefully stops the serverless services. func (s *Serverless) GracefulStop() error { var wg sync.WaitGroup - for svcName, svc := range s.services { + for svcName, svc := range s.Services { wg.Add(1) go func(svcName string, svc Service) { @@ -113,7 +102,7 @@ func (s *Serverless) GracefulStop() error { func (s *Serverless) Stop() error { var wg sync.WaitGroup - for svcName, svc := range s.services { + for svcName, svc := range s.Services { wg.Add(1) go func(svcName string, svc Service) { @@ -136,26 +125,3 @@ func (s *Serverless) Stop() error { return nil } - -func (s *Serverless) registerAndStartServices() error { - for svcName := range s.conf.Services { - if s.isServiceEnabled(svcName) { - newFunc := Services[svcName] - svcLogger := s.log.With().Str("service", svcName).Logger() - svc, err := newFunc(s.conf.Services[svcName], &svcLogger) - if err != nil { - return errors.Wrapf(err, "serverless service %s could not be initialized", svcName) - } - - go svc.Start() - - s.services[svcName] = svc - - s.log.Info().Msgf("serverless service enabled: %s", svcName) - } else { - return fmt.Errorf("serverless service %s does not exist", svcName) - } - } - - return nil -} diff --git a/pkg/sharedconf/sharedconf.go b/pkg/sharedconf/sharedconf.go index 0503a5a806..c40aa6c0c0 100644 --- a/pkg/sharedconf/sharedconf.go +++ b/pkg/sharedconf/sharedconf.go @@ -19,50 +19,18 @@ package sharedconf import ( - "fmt" - "os" + "sync" - "github.com/mitchellh/mapstructure" + "github.com/cs3org/reva/cmd/revad/pkg/config" ) -var sharedConf = &conf{} +var sharedConf *config.Shared = &config.Shared{} +var once sync.Once -type conf struct { - JWTSecret string `mapstructure:"jwt_secret"` - GatewaySVC string `mapstructure:"gatewaysvc"` - DataGateway string `mapstructure:"datagateway"` - SkipUserGroupsInToken bool `mapstructure:"skip_user_groups_in_token"` - BlockedUsers []string `mapstructure:"blocked_users"` -} - -// Decode decodes the configuration. -func Decode(v interface{}) error { - if err := mapstructure.Decode(v, sharedConf); err != nil { - return err - } - - // add some defaults - if sharedConf.GatewaySVC == "" { - sharedConf.GatewaySVC = "0.0.0.0:19000" - } - - // this is the default address we use for the data gateway HTTP service - if sharedConf.DataGateway == "" { - host, err := os.Hostname() - if err != nil || host == "" { - sharedConf.DataGateway = "http://0.0.0.0:19001/datagateway" - } else { - sharedConf.DataGateway = fmt.Sprintf("http://%s:19001/datagateway", host) - } - } - - // TODO(labkode): would be cool to autogenerate one secret and print - // it on init time. - if sharedConf.JWTSecret == "" { - sharedConf.JWTSecret = "changemeplease" - } - - return nil +func Init(c *config.Shared) { + once.Do(func() { + sharedConf = c + }) } // GetJWTSecret returns the package level configured jwt secret if not overwritten. diff --git a/pkg/sharedconf/sharedconf_test.go b/pkg/sharedconf/sharedconf_test.go index e6ad94a777..0140e045bf 100644 --- a/pkg/sharedconf/sharedconf_test.go +++ b/pkg/sharedconf/sharedconf_test.go @@ -23,42 +23,42 @@ import ( ) func Test(t *testing.T) { - conf := map[string]interface{}{ - "jwt_secret": "", - "gateway": "", - } + // conf := map[string]interface{}{ + // "jwt_secret": "", + // "gateway": "", + // } - err := Decode(conf) - if err != nil { - t.Fatal(err) - } + // err := Decode(conf) + // if err != nil { + // t.Fatal(err) + // } - got := GetJWTSecret("secret") - if got != "secret" { - t.Fatalf("expected %q got %q", "secret", got) - } + // got := GetJWTSecret("secret") + // if got != "secret" { + // t.Fatalf("expected %q got %q", "secret", got) + // } - got = GetJWTSecret("") - if got != "changemeplease" { - t.Fatalf("expected %q got %q", "changemeplease", got) - } + // got = GetJWTSecret("") + // if got != "changemeplease" { + // t.Fatalf("expected %q got %q", "changemeplease", got) + // } - conf = map[string]interface{}{ - "jwt_secret": "dummy", - } + // conf = map[string]interface{}{ + // "jwt_secret": "dummy", + // } - err = Decode(conf) - if err != nil { - t.Fatal(err) - } + // err = Decode(conf) + // if err != nil { + // t.Fatal(err) + // } - got = GetJWTSecret("secret") - if got != "secret" { - t.Fatalf("expected %q got %q", "secret", got) - } + // got = GetJWTSecret("secret") + // if got != "secret" { + // t.Fatalf("expected %q got %q", "secret", got) + // } - got = GetJWTSecret("") - if got != "dummy" { - t.Fatalf("expected %q got %q", "dummy", got) - } + // got = GetJWTSecret("") + // if got != "dummy" { + // t.Fatalf("expected %q got %q", "dummy", got) + // } } diff --git a/pkg/utils/list/list.go b/pkg/utils/list/list.go index 2b1962ab2b..de279ebf77 100644 --- a/pkg/utils/list/list.go +++ b/pkg/utils/list/list.go @@ -34,3 +34,25 @@ func Remove[T any](l []T, i int) []T { l[i] = l[len(l)-1] return l[:len(l)-1] } + +// TakeFirst returns the first elemen, if any, that satisfies +// the predicate p. +func TakeFirst[T any](l []T, p func(T) bool) (T, bool) { + for _, e := range l { + if p(e) { + return e, true + } + } + var z T + return z, false +} + +// ToMap returns a map from l where the keys are obtainined applying +// the func k to the elements of l. +func ToMap[K comparable, T any](l []T, k func(T) K) map[K]T { + m := make(map[K]T, len(l)) + for _, e := range l { + m[k(e)] = e + } + return m +} diff --git a/pkg/utils/maps/maps.go b/pkg/utils/maps/maps.go new file mode 100644 index 0000000000..0b04f84025 --- /dev/null +++ b/pkg/utils/maps/maps.go @@ -0,0 +1,50 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package maps + +// Merge returns a map containing the keys and values from both maps. +// If the two maps share a set of keys, the result map will contain +// only the value of the second map. +func Merge[K comparable, T any](m, n map[K]T) map[K]T { + r := make(map[K]T, len(m)+len(n)) + for k, v := range m { + r[k] = v + } + for k, v := range n { + r[k] = v + } + return r +} + +// MapValues returns a map with vales mapped using the function f. +func MapValues[K comparable, T, V any](m map[K]T, f func(T) V) map[K]V { + r := make(map[K]V, len(m)) + for k, v := range m { + r[k] = f(v) + } + return r +} + +func Keys[K comparable, V any](m map[K]V) []K { + l := make([]K, 0, len(m)) + for k := range m { + l = append(l, k) + } + return l +} diff --git a/pkg/utils/net/net.go b/pkg/utils/net/net.go new file mode 100644 index 0000000000..51b94117f0 --- /dev/null +++ b/pkg/utils/net/net.go @@ -0,0 +1,54 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package net + +import "net" + +// AddressEqual return true if the addresses are equal. +// For tpc addresses only the port is compared, for unix +// the name and net are compared. +func AddressEqual(a net.Addr, network, address string) bool { + if a.Network() != network { + return false + } + + switch network { + case "tcp": + t, err := net.ResolveTCPAddr(network, address) + if err != nil { + return false + } + return tcpAddressEqual(a.(*net.TCPAddr), t) + case "unix": + t, err := net.ResolveUnixAddr(network, address) + if err != nil { + return false + } + return unixAddressEqual(a.(*net.UnixAddr), t) + } + return false +} + +func tcpAddressEqual(a1, a2 *net.TCPAddr) bool { + return a1.Port == a2.Port +} + +func unixAddressEqual(a1, a2 *net.UnixAddr) bool { + return a1.Name == a2.Name && a1.Net == a2.Net +}