From 7f67bc52ab5b59657eda383c99235e0c48a48493 Mon Sep 17 00:00:00 2001 From: Gaius Date: Tue, 8 Aug 2023 22:54:19 +0800 Subject: [PATCH] feat: add tls client config for preheat in manager (#2612) Signed-off-by: Gaius --- manager/config/config.go | 39 +++++++++++++++++++++++----- manager/config/config_test.go | 34 ++++++++++++++++++++---- manager/config/testdata/manager.yaml | 4 +++ manager/job/job.go | 13 +++++++++- manager/job/preheat.go | 18 ++++++------- 5 files changed, 87 insertions(+), 21 deletions(-) diff --git a/manager/config/config.go b/manager/config/config.go index ab6a71e2d70..d0ee305138d 100644 --- a/manager/config/config.go +++ b/manager/config/config.go @@ -46,6 +46,9 @@ type Config struct { // Cache configuration. Cache CacheConfig `yaml:"cache" mapstructure:"cache"` + // Job configuration. + Job JobConfig `yaml:"job" mapstructure:"job"` + // ObjectStorage configuration. ObjectStorage ObjectStorageConfig `yaml:"objectStorage" mapstructure:"objectStorage"` @@ -137,14 +140,14 @@ type MysqlConfig struct { // TLS mode (can be one of "true", "false", "skip-verify", or "preferred"). TLSConfig string `yaml:"tlsConfig" mapstructure:"tlsConfig"` - // Custom TLS configuration (overrides "TLSConfig" setting above). - TLS *MysqlTLSConfig `yaml:"tls" mapstructure:"tls"` + // Custom TLS client configuration (overrides "TLSConfig" setting above). + TLS *MysqlTLSClientConfig `yaml:"tls" mapstructure:"tls"` // Enable migration. Migrate bool `yaml:"migrate" mapstructure:"migrate"` } -type MysqlTLSConfig struct { +type MysqlTLSClientConfig struct { // Client certificate file path. Cert string `yaml:"cert" mapstructure:"cert"` @@ -239,11 +242,11 @@ type RESTConfig struct { // REST server address. Addr string `yaml:"addr" mapstructure:"addr"` - // TLS configuration. - TLS *RESTTLSConfig `yaml:"tls" mapstructure:"tls"` + // TLS server configuration. + TLS *TLSServerConfig `yaml:"tls" mapstructure:"tls"` } -type RESTTLSConfig struct { +type TLSServerConfig struct { // Certificate file path. Cert string `yaml:"cert" mapstructure:"cert"` @@ -281,6 +284,21 @@ type TCPListenPortRange struct { End int } +type JobConfig struct { + // Preheat configuration. + Preheat PreheatConfig `yaml:"preheat" mapstructure:"preheat"` +} + +type PreheatConfig struct { + // TLS client configuration. + TLS *TLSClientConfig `yaml:"tls" mapstructure:"tls"` +} + +type TLSClientConfig struct { + // CACert is the CA certificate for preheat tls handshake, it can be path or PEM format string. + CACert types.PEMContent `yaml:"caCert" mapstructure:"caCert"` +} + type ObjectStorageConfig struct { // Enable object storage. Enable bool `yaml:"enable" mapstructure:"enable"` @@ -405,6 +423,9 @@ func New() *Config { TTL: DefaultLFUCacheTTL, }, }, + Job: JobConfig{ + Preheat: PreheatConfig{}, + }, ObjectStorage: ObjectStorageConfig{ Enable: false, S3ForcePathStyle: true, @@ -575,6 +596,12 @@ func (cfg *Config) Validate() error { return errors.New("local requires parameter ttl") } + if cfg.Job.Preheat.TLS != nil { + if cfg.Job.Preheat.TLS.CACert == "" { + return errors.New("preheat requires parameter caCert") + } + } + if cfg.ObjectStorage.Enable { if cfg.ObjectStorage.Name == "" { return errors.New("objectStorage requires parameter name") diff --git a/manager/config/config_test.go b/manager/config/config_test.go index 43d44c2ad91..4000078a7eb 100644 --- a/manager/config/config_test.go +++ b/manager/config/config_test.go @@ -48,7 +48,7 @@ var ( Migrate: true, } - mockMysqlTLSConfig = &MysqlTLSConfig{ + mockMysqlTLSConfig = &MysqlTLSClientConfig{ Cert: "ca.crt", Key: "ca.key", CA: "ca", @@ -129,7 +129,7 @@ func TestConfig_Load(t *testing.T) { }, REST: RESTConfig{ Addr: ":8080", - TLS: &RESTTLSConfig{ + TLS: &TLSServerConfig{ Cert: "foo", Key: "foo", }, @@ -152,7 +152,7 @@ func TestConfig_Load(t *testing.T) { Port: 3306, DBName: "foo", TLSConfig: "preferred", - TLS: &MysqlTLSConfig{ + TLS: &MysqlTLSClientConfig{ Cert: "foo", Key: "foo", CA: "foo", @@ -188,6 +188,13 @@ func TestConfig_Load(t *testing.T) { TTL: 1 * time.Second, }, }, + Job: JobConfig{ + Preheat: PreheatConfig{ + TLS: &TLSClientConfig{ + CACert: "foo", + }, + }, + }, ObjectStorage: ObjectStorageConfig{ Enable: true, Name: objectstorage.ServiceNameS3, @@ -300,7 +307,7 @@ func TestConfig_Validate(t *testing.T) { name: "rest tls requires parameter cert", config: New(), mock: func(cfg *Config) { - cfg.Server.REST.TLS = &RESTTLSConfig{ + cfg.Server.REST.TLS = &TLSServerConfig{ Cert: "", Key: "foo", } @@ -314,7 +321,7 @@ func TestConfig_Validate(t *testing.T) { name: "rest tls requires parameter key", config: New(), mock: func(cfg *Config) { - cfg.Server.REST.TLS = &RESTTLSConfig{ + cfg.Server.REST.TLS = &TLSServerConfig{ Cert: "foo", Key: "", } @@ -703,6 +710,23 @@ func TestConfig_Validate(t *testing.T) { assert.EqualError(err, "local requires parameter ttl") }, }, + { + name: "preheat requires parameter caCert", + config: New(), + mock: func(cfg *Config) { + cfg.Auth.JWT = mockJWTConfig + cfg.Database.Type = DatabaseTypeMysql + cfg.Database.Mysql = mockMysqlConfig + cfg.Database.Redis = mockRedisConfig + cfg.Job.Preheat.TLS = &TLSClientConfig{ + CACert: "", + } + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "preheat requires parameter caCert") + }, + }, { name: "objectStorage requires parameter name", config: New(), diff --git a/manager/config/testdata/manager.yaml b/manager/config/testdata/manager.yaml index d8c78a9e755..a2bc25a71ca 100644 --- a/manager/config/testdata/manager.yaml +++ b/manager/config/testdata/manager.yaml @@ -63,6 +63,10 @@ cache: size: 10000 ttl: 1s +job: + preheat: + tls: + caCert: testdata/ca.crt objectStorage: enable: true name: s3 diff --git a/manager/job/job.go b/manager/job/job.go index 13cdffbb0b8..ba65097db46 100644 --- a/manager/job/job.go +++ b/manager/job/job.go @@ -17,6 +17,9 @@ package job import ( + "crypto/x509" + "errors" + internaljob "d7y.io/dragonfly/v2/internal/job" "d7y.io/dragonfly/v2/manager/config" ) @@ -39,7 +42,15 @@ func New(cfg *config.Config) (*Job, error) { return nil, err } - p, err := newPreheat(j) + var certPool *x509.CertPool + if cfg.Job.Preheat.TLS != nil { + certPool = x509.NewCertPool() + if !certPool.AppendCertsFromPEM([]byte(cfg.Job.Preheat.TLS.CACert)) { + return nil, errors.New("invalid CA Cert") + } + } + + p, err := newPreheat(j, certPool) if err != nil { return nil, err } diff --git a/manager/job/preheat.go b/manager/job/preheat.go index 3b8bcb1ddbf..b80548856eb 100644 --- a/manager/job/preheat.go +++ b/manager/job/preheat.go @@ -21,6 +21,7 @@ package job import ( "context" "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -70,7 +71,8 @@ type Preheat interface { } type preheat struct { - job *internaljob.Job + job *internaljob.Job + rootCAs *x509.CertPool } type preheatImage struct { @@ -80,10 +82,8 @@ type preheatImage struct { tag string } -func newPreheat(job *internaljob.Job) (Preheat, error) { - return &preheat{ - job: job, - }, nil +func newPreheat(job *internaljob.Job, rootCAs *x509.CertPool) (Preheat, error) { + return &preheat{job, rootCAs}, nil } func (p *preheat) CreatePreheat(ctx context.Context, schedulers []models.Scheduler, json types.PreheatArgs) (*internaljob.GroupJobState, error) { @@ -185,7 +185,7 @@ func (p *preheat) getLayers(ctx context.Context, url, tag, filter string, header if resp.StatusCode/100 != 2 { if resp.StatusCode == http.StatusUnauthorized { - token, err := getAuthToken(ctx, resp.Header) + token, err := getAuthToken(ctx, resp.Header, p.rootCAs) if err != nil { return nil, err } @@ -221,7 +221,7 @@ func (p *preheat) getManifests(ctx context.Context, url string, header http.Head Timeout: defaultHTTPRequesttimeout, Transport: &http.Transport{ DialContext: nethttp.NewSafeDialer().DialContext, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + TLSClientConfig: &tls.Config{RootCAs: p.rootCAs}, }, } @@ -259,7 +259,7 @@ func (p *preheat) parseLayers(resp *http.Response, url, tag, filter string, head return layers, nil } -func getAuthToken(ctx context.Context, header http.Header) (string, error) { +func getAuthToken(ctx context.Context, header http.Header, rootCAs *x509.CertPool) (string, error) { ctx, span := tracer.Start(ctx, config.SpanAuthWithRegistry, trace.WithSpanKind(trace.SpanKindProducer)) defer span.End() @@ -277,7 +277,7 @@ func getAuthToken(ctx context.Context, header http.Header) (string, error) { Timeout: defaultHTTPRequesttimeout, Transport: &http.Transport{ DialContext: nethttp.NewSafeDialer().DialContext, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + TLSClientConfig: &tls.Config{RootCAs: rootCAs}, }, }