diff --git a/trainer/config/config.go b/trainer/config/config.go new file mode 100644 index 00000000000..785728895ba --- /dev/null +++ b/trainer/config/config.go @@ -0,0 +1,200 @@ +package trainer + +import ( + "errors" + "net" + "time" + + "d7y.io/dragonfly/v2/pkg/net/ip" + "d7y.io/dragonfly/v2/pkg/rpc" + "d7y.io/dragonfly/v2/pkg/slices" + "d7y.io/dragonfly/v2/pkg/types" +) + +type Config struct { + // Network configuration. + Network NetworkConfig `yaml:"network" mapstructure:"network"` + + // Server configuration. + Server ServerConfig `yaml:"server" mapstructure:"server"` + + // Metrics configuration. + Metrics MetricsConfig `yaml:"metrics" mapstructure:"metrics"` + + // Security configuration. + Security SecurityConfig `yaml:"security" mapstructure:"security"` + + // Manager configuration. + Manager ManagerConfig `yaml:"manager" mapstructure:"manager"` +} + +type NetworkConfig struct { + // EnableIPv6 enables ipv6 for server. + EnableIPv6 bool `yaml:"enableIPv6" mapstructure:"enableIPv6"` +} + +type ServerConfig struct { + // AdvertiseIP is advertise ip. + AdvertiseIP net.IP `yaml:"advertiseIP" mapstructure:"advertiseIP"` + + // AdvertisePort is advertise port. + AdvertisePort int `yaml:"advertisePort" mapstructure:"advertisePort"` + + // ListenIP is listen ip, like: 0.0.0.0, 192.168.0.1. + ListenIP net.IP `yaml:"listenIP" mapstructure:"listenIP"` + + // Server port. + Port int `yaml:"port" mapstructure:"port"` + + // Server log directory. + LogDir string `yaml:"logDir" mapstructure:"logDir"` + + // Server storage data directory. + DataDir string `yaml:"dataDir" mapstructure:"dataDir"` +} + +type MetricsConfig struct { + // Enable metrics service. + Enable bool `yaml:"enable" mapstructure:"enable"` + + // Metrics service address. + Addr string `yaml:"addr" mapstructure:"addr"` +} + +type SecurityConfig struct { + // AutoIssueCert indicates to issue client certificates for all grpc call + // if AutoIssueCert is false, any other option in Security will be ignored. + AutoIssueCert bool `mapstructure:"autoIssueCert" yaml:"autoIssueCert"` + + // CACert is the root CA certificate for all grpc tls handshake, it can be path or PEM format string. + CACert types.PEMContent `mapstructure:"caCert" yaml:"caCert"` + + // TLSVerify indicates to verify client certificates. + TLSVerify bool `mapstructure:"tlsVerify" yaml:"tlsVerify"` + + // TLSPolicy controls the grpc shandshake behaviors: + // force: both ClientHandshake and ServerHandshake are only support tls. + // prefer: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support tls. + // default: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support insecure (non-tls). + TLSPolicy string `mapstructure:"tlsPolicy" yaml:"tlsPolicy"` + + // CertSpec is the desired state of certificate. + CertSpec CertSpec `mapstructure:"certSpec" yaml:"certSpec"` +} + +type CertSpec struct { + // DNSNames is a list of dns names be set on the certificate. + DNSNames []string `mapstructure:"dnsNames" yaml:"dnsNames"` + + // IPAddresses is a list of ip addresses be set on the certificate. + IPAddresses []net.IP `mapstructure:"ipAddresses" yaml:"ipAddresses"` + + // ValidityPeriod is the validity period of certificate. + ValidityPeriod time.Duration `mapstructure:"validityPeriod" yaml:"validityPeriod"` +} + +type ManagerConfig struct { + // Addr is manager address. + Addr string `yaml:"addr" mapstructure:"addr"` +} + +// New default configuration. +func New() *Config { + return &Config{ + Network: NetworkConfig{ + EnableIPv6: DefaultNetworkEnableIPv6, + }, + Server: ServerConfig{ + AdvertisePort: DefaultServerAdvertisePort, + Port: DefaultServerPort, + }, + Metrics: MetricsConfig{ + Enable: false, + Addr: DefaultMetricsAddr, + }, + Security: SecurityConfig{ + AutoIssueCert: false, + TLSVerify: true, + TLSPolicy: rpc.PreferTLSPolicy, + CertSpec: CertSpec{ + DNSNames: DefaultCertDNSNames, + IPAddresses: DefaultCertIPAddresses, + ValidityPeriod: DefaultCertValidityPeriod, + }, + }, + Manager: ManagerConfig{}, + } +} + +// Validate config parameters. +func (cfg *Config) Validate() error { + if cfg.Server.AdvertiseIP == nil { + return errors.New("server requires parameter advertiseIP") + } + + if cfg.Server.AdvertisePort <= 0 { + return errors.New("server requires parameter advertisePort") + } + + if cfg.Server.ListenIP == nil { + return errors.New("server requires parameter listenIP") + } + + if cfg.Server.Port <= 0 { + return errors.New("server requires parameter port") + } + + if cfg.Metrics.Enable { + if cfg.Metrics.Addr == "" { + return errors.New("metrics requires parameter addr") + } + } + + if cfg.Security.AutoIssueCert { + if cfg.Security.CACert == "" { + return errors.New("security requires parameter caCert") + } + + if !slices.Contains([]string{rpc.DefaultTLSPolicy, rpc.ForceTLSPolicy, rpc.PreferTLSPolicy}, cfg.Security.TLSPolicy) { + return errors.New("security requires parameter tlsPolicy") + } + + if len(cfg.Security.CertSpec.IPAddresses) == 0 { + return errors.New("certSpec requires parameter ipAddresses") + } + + if len(cfg.Security.CertSpec.DNSNames) == 0 { + return errors.New("certSpec requires parameter dnsNames") + } + + if cfg.Security.CertSpec.ValidityPeriod <= 0 { + return errors.New("certSpec requires parameter validityPeriod") + } + } + + if cfg.Manager.Addr == "" { + return errors.New("manager requires parameter addr") + } + + return nil +} + +func (cfg *Config) Convert() error { + if cfg.Server.AdvertiseIP == nil { + if cfg.Network.EnableIPv6 { + cfg.Server.AdvertiseIP = ip.IPv6 + } else { + cfg.Server.AdvertiseIP = ip.IPv4 + } + } + + if cfg.Server.ListenIP == nil { + if cfg.Network.EnableIPv6 { + cfg.Server.ListenIP = net.IPv6zero + } else { + cfg.Server.ListenIP = net.IPv4zero + } + } + + return nil +} diff --git a/trainer/config/config_test.go b/trainer/config/config_test.go new file mode 100644 index 00000000000..fccbf8582e1 --- /dev/null +++ b/trainer/config/config_test.go @@ -0,0 +1,248 @@ +package trainer + +import ( + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" + + "d7y.io/dragonfly/v2/pkg/rpc" + "d7y.io/dragonfly/v2/pkg/types" +) + +var ( + mockManagerConfig = ManagerConfig{ + Addr: "localhost", + } + + mockMetricsConfig = MetricsConfig{ + Enable: true, + Addr: DefaultMetricsAddr, + } + + mockSecurityConfig = SecurityConfig{ + AutoIssueCert: true, + CACert: types.PEMContent("foo"), + TLSPolicy: rpc.PreferTLSPolicy, + CertSpec: CertSpec{ + DNSNames: DefaultCertDNSNames, + IPAddresses: DefaultCertIPAddresses, + ValidityPeriod: DefaultCertValidityPeriod, + }, + } +) + +func TestConfig_Load(t *testing.T) { + config := &Config{ + Network: NetworkConfig{ + EnableIPv6: true, + }, + Server: ServerConfig{ + AdvertiseIP: net.ParseIP("127.0.0.1"), + AdvertisePort: 9090, + ListenIP: net.ParseIP("0.0.0.0"), + Port: 9092, + LogDir: "foo", + DataDir: "foo", + }, + Metrics: MetricsConfig{ + Enable: false, + Addr: ":8000", + }, + Security: SecurityConfig{ + AutoIssueCert: true, + CACert: "foo", + TLSVerify: true, + TLSPolicy: "force", + CertSpec: CertSpec{ + DNSNames: []string{"foo"}, + IPAddresses: []net.IP{net.IPv4zero}, + ValidityPeriod: 10 * time.Minute, + }, + }, + Manager: ManagerConfig{ + Addr: "127.0.0.1:65003", + }, + } + + trainerConfigYAML := &Config{} + contentYAML, _ := os.ReadFile("./testdata/trainer.yaml") + if err := yaml.Unmarshal(contentYAML, &trainerConfigYAML); err != nil { + t.Fatal(err) + } + assert := assert.New(t) + assert.EqualValues(config, trainerConfigYAML) +} + +func TestConfig_Validate(t *testing.T) { + tests := []struct { + name string + config *Config + mock func(cfg *Config) + expect func(t *testing.T, err error) + }{ + { + name: "valid config", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.NoError(err) + }, + }, + { + name: "server requires parameter advertiseIP", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Server.AdvertiseIP = nil + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "server requires parameter advertiseIP") + }, + }, + { + name: "server requires parameter advertisePort", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Server.AdvertisePort = 0 + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "server requires parameter advertisePort") + }, + }, + { + name: "server requires parameter listenIP", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Server.ListenIP = nil + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "server requires parameter listenIP") + }, + }, + { + name: "server requires parameter port", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Server.Port = 0 + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "server requires parameter port") + }, + }, + { + name: "metrics requires parameter addr", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Metrics = mockMetricsConfig + cfg.Metrics.Addr = "" + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "metrics requires parameter addr") + }, + }, + { + name: "security requires parameter caCert", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Security = mockSecurityConfig + cfg.Security.CACert = "" + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "security requires parameter caCert") + }, + }, + { + name: "security requires parameter tlsPolicy", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Security = mockSecurityConfig + cfg.Security.TLSPolicy = "" + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "security requires parameter tlsPolicy") + }, + }, + { + name: "certSpec requires parameter ipAddresses", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Security = mockSecurityConfig + cfg.Security.CertSpec.IPAddresses = []net.IP{} + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "certSpec requires parameter ipAddresses") + }, + }, + { + name: "certSpec requires parameter dnsNames", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Security = mockSecurityConfig + cfg.Security.CertSpec.DNSNames = []string{} + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "certSpec requires parameter dnsNames") + }, + }, + { + name: "certSpec requires parameter validityPeriod", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Security = mockSecurityConfig + cfg.Security.CertSpec.ValidityPeriod = 0 + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "certSpec requires parameter validityPeriod") + }, + }, + { + name: "manager requires parameter addr", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Manager.Addr = "" + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "manager requires parameter addr") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if err := tc.config.Convert(); err != nil { + t.Fatal(err) + } + + tc.mock(tc.config) + tc.expect(t, tc.config.Validate()) + }) + } +} diff --git a/trainer/config/constants.go b/trainer/config/constants.go new file mode 100644 index 00000000000..ffe24fabd8a --- /dev/null +++ b/trainer/config/constants.go @@ -0,0 +1,37 @@ +package trainer + +import ( + "net" + "time" + + "d7y.io/dragonfly/v2/pkg/net/ip" +) + +const ( + // DefaultServerPort is default port for server. + DefaultServerPort = 9092 + + // DefaultServerAdvertisePort is default advertise port for server. + DefaultServerAdvertisePort = 9090 +) + +const ( + // DefaultMetricsAddr is default address for metrics server. + DefaultMetricsAddr = ":8000" +) + +var ( + // DefaultCertIPAddresses is default ip addresses of certificate. + DefaultCertIPAddresses = []net.IP{ip.IPv4, ip.IPv6} + + // DefaultCertDNSNames is default dns names of certificate. + DefaultCertDNSNames = []string{"dragonfly-trainer", "dragonfly-trainer.dragonfly-system.svc", "dragonfly-trainer.dragonfly-system.svc.cluster.local"} + + // DefaultCertValidityPeriod is default validity period of certificate. + DefaultCertValidityPeriod = 180 * 24 * time.Hour +) + +var ( + // DefaultNetworkEnableIPv6 is default value of enableIPv6. + DefaultNetworkEnableIPv6 = false +) diff --git a/trainer/config/testdata/ca.crt b/trainer/config/testdata/ca.crt new file mode 100644 index 00000000000..257cc5642cb --- /dev/null +++ b/trainer/config/testdata/ca.crt @@ -0,0 +1 @@ +foo diff --git a/trainer/config/testdata/trainer.yaml b/trainer/config/testdata/trainer.yaml new file mode 100644 index 00000000000..58f61a9c8e8 --- /dev/null +++ b/trainer/config/testdata/trainer.yaml @@ -0,0 +1,30 @@ +network: + enableIPv6: true + +server: + advertiseIP: 127.0.0.1 + advertisePort: 9090 + listenIP: 0.0.0.0 + port: 9092 + host: foo + logDir: foo + dataDir: foo + +metrics: + enable: false + addr: ":8000" + +security: + autoIssueCert: true + caCert: testdata/ca.crt + tlsVerify: true + tlsPolicy: force + certSpec: + dnsNames: + - foo + ipAddresses: + - 0.0.0.0 + validityPeriod: 10m + +manager: + addr: 127.0.0.1:65003