Skip to content

Commit

Permalink
feat: add tls client config for preheat in manager (#2612)
Browse files Browse the repository at this point in the history
Signed-off-by: Gaius <[email protected]>
  • Loading branch information
gaius-qi authored Aug 8, 2023
1 parent da7e305 commit 7f67bc5
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 21 deletions.
39 changes: 33 additions & 6 deletions manager/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type Config struct {
// Cache configuration.
Cache CacheConfig `yaml:"cache" mapstructure:"cache"`

// Job configuration.
Job JobConfig `yaml:"job" mapstructure:"job"`

// ObjectStorage configuration.
ObjectStorage ObjectStorageConfig `yaml:"objectStorage" mapstructure:"objectStorage"`

Expand Down Expand Up @@ -137,14 +140,14 @@ type MysqlConfig struct {
// TLS mode (can be one of "true", "false", "skip-verify", or "preferred").
TLSConfig string `yaml:"tlsConfig" mapstructure:"tlsConfig"`

// Custom TLS configuration (overrides "TLSConfig" setting above).
TLS *MysqlTLSConfig `yaml:"tls" mapstructure:"tls"`
// Custom TLS client configuration (overrides "TLSConfig" setting above).
TLS *MysqlTLSClientConfig `yaml:"tls" mapstructure:"tls"`

// Enable migration.
Migrate bool `yaml:"migrate" mapstructure:"migrate"`
}

type MysqlTLSConfig struct {
type MysqlTLSClientConfig struct {
// Client certificate file path.
Cert string `yaml:"cert" mapstructure:"cert"`

Expand Down Expand Up @@ -239,11 +242,11 @@ type RESTConfig struct {
// REST server address.
Addr string `yaml:"addr" mapstructure:"addr"`

// TLS configuration.
TLS *RESTTLSConfig `yaml:"tls" mapstructure:"tls"`
// TLS server configuration.
TLS *TLSServerConfig `yaml:"tls" mapstructure:"tls"`
}

type RESTTLSConfig struct {
type TLSServerConfig struct {
// Certificate file path.
Cert string `yaml:"cert" mapstructure:"cert"`

Expand Down Expand Up @@ -281,6 +284,21 @@ type TCPListenPortRange struct {
End int
}

type JobConfig struct {
// Preheat configuration.
Preheat PreheatConfig `yaml:"preheat" mapstructure:"preheat"`
}

type PreheatConfig struct {
// TLS client configuration.
TLS *TLSClientConfig `yaml:"tls" mapstructure:"tls"`
}

type TLSClientConfig struct {
// CACert is the CA certificate for preheat tls handshake, it can be path or PEM format string.
CACert types.PEMContent `yaml:"caCert" mapstructure:"caCert"`
}

type ObjectStorageConfig struct {
// Enable object storage.
Enable bool `yaml:"enable" mapstructure:"enable"`
Expand Down Expand Up @@ -405,6 +423,9 @@ func New() *Config {
TTL: DefaultLFUCacheTTL,
},
},
Job: JobConfig{
Preheat: PreheatConfig{},
},
ObjectStorage: ObjectStorageConfig{
Enable: false,
S3ForcePathStyle: true,
Expand Down Expand Up @@ -575,6 +596,12 @@ func (cfg *Config) Validate() error {
return errors.New("local requires parameter ttl")
}

if cfg.Job.Preheat.TLS != nil {
if cfg.Job.Preheat.TLS.CACert == "" {
return errors.New("preheat requires parameter caCert")
}
}

if cfg.ObjectStorage.Enable {
if cfg.ObjectStorage.Name == "" {
return errors.New("objectStorage requires parameter name")
Expand Down
34 changes: 29 additions & 5 deletions manager/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ var (
Migrate: true,
}

mockMysqlTLSConfig = &MysqlTLSConfig{
mockMysqlTLSConfig = &MysqlTLSClientConfig{
Cert: "ca.crt",
Key: "ca.key",
CA: "ca",
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestConfig_Load(t *testing.T) {
},
REST: RESTConfig{
Addr: ":8080",
TLS: &RESTTLSConfig{
TLS: &TLSServerConfig{
Cert: "foo",
Key: "foo",
},
Expand All @@ -152,7 +152,7 @@ func TestConfig_Load(t *testing.T) {
Port: 3306,
DBName: "foo",
TLSConfig: "preferred",
TLS: &MysqlTLSConfig{
TLS: &MysqlTLSClientConfig{
Cert: "foo",
Key: "foo",
CA: "foo",
Expand Down Expand Up @@ -188,6 +188,13 @@ func TestConfig_Load(t *testing.T) {
TTL: 1 * time.Second,
},
},
Job: JobConfig{
Preheat: PreheatConfig{
TLS: &TLSClientConfig{
CACert: "foo",
},
},
},
ObjectStorage: ObjectStorageConfig{
Enable: true,
Name: objectstorage.ServiceNameS3,
Expand Down Expand Up @@ -300,7 +307,7 @@ func TestConfig_Validate(t *testing.T) {
name: "rest tls requires parameter cert",
config: New(),
mock: func(cfg *Config) {
cfg.Server.REST.TLS = &RESTTLSConfig{
cfg.Server.REST.TLS = &TLSServerConfig{
Cert: "",
Key: "foo",
}
Expand All @@ -314,7 +321,7 @@ func TestConfig_Validate(t *testing.T) {
name: "rest tls requires parameter key",
config: New(),
mock: func(cfg *Config) {
cfg.Server.REST.TLS = &RESTTLSConfig{
cfg.Server.REST.TLS = &TLSServerConfig{
Cert: "foo",
Key: "",
}
Expand Down Expand Up @@ -703,6 +710,23 @@ func TestConfig_Validate(t *testing.T) {
assert.EqualError(err, "local requires parameter ttl")
},
},
{
name: "preheat requires parameter caCert",
config: New(),
mock: func(cfg *Config) {
cfg.Auth.JWT = mockJWTConfig
cfg.Database.Type = DatabaseTypeMysql
cfg.Database.Mysql = mockMysqlConfig
cfg.Database.Redis = mockRedisConfig
cfg.Job.Preheat.TLS = &TLSClientConfig{
CACert: "",
}
},
expect: func(t *testing.T, err error) {
assert := assert.New(t)
assert.EqualError(err, "preheat requires parameter caCert")
},
},
{
name: "objectStorage requires parameter name",
config: New(),
Expand Down
4 changes: 4 additions & 0 deletions manager/config/testdata/manager.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ cache:
size: 10000
ttl: 1s

job:
preheat:
tls:
caCert: testdata/ca.crt
objectStorage:
enable: true
name: s3
Expand Down
13 changes: 12 additions & 1 deletion manager/job/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package job

import (
"crypto/x509"
"errors"

internaljob "d7y.io/dragonfly/v2/internal/job"
"d7y.io/dragonfly/v2/manager/config"
)
Expand All @@ -39,7 +42,15 @@ func New(cfg *config.Config) (*Job, error) {
return nil, err
}

p, err := newPreheat(j)
var certPool *x509.CertPool
if cfg.Job.Preheat.TLS != nil {
certPool = x509.NewCertPool()
if !certPool.AppendCertsFromPEM([]byte(cfg.Job.Preheat.TLS.CACert)) {
return nil, errors.New("invalid CA Cert")
}
}

p, err := newPreheat(j, certPool)
if err != nil {
return nil, err
}
Expand Down
18 changes: 9 additions & 9 deletions manager/job/preheat.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package job
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -70,7 +71,8 @@ type Preheat interface {
}

type preheat struct {
job *internaljob.Job
job *internaljob.Job
rootCAs *x509.CertPool
}

type preheatImage struct {
Expand All @@ -80,10 +82,8 @@ type preheatImage struct {
tag string
}

func newPreheat(job *internaljob.Job) (Preheat, error) {
return &preheat{
job: job,
}, nil
func newPreheat(job *internaljob.Job, rootCAs *x509.CertPool) (Preheat, error) {
return &preheat{job, rootCAs}, nil
}

func (p *preheat) CreatePreheat(ctx context.Context, schedulers []models.Scheduler, json types.PreheatArgs) (*internaljob.GroupJobState, error) {
Expand Down Expand Up @@ -185,7 +185,7 @@ func (p *preheat) getLayers(ctx context.Context, url, tag, filter string, header

if resp.StatusCode/100 != 2 {
if resp.StatusCode == http.StatusUnauthorized {
token, err := getAuthToken(ctx, resp.Header)
token, err := getAuthToken(ctx, resp.Header, p.rootCAs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -221,7 +221,7 @@ func (p *preheat) getManifests(ctx context.Context, url string, header http.Head
Timeout: defaultHTTPRequesttimeout,
Transport: &http.Transport{
DialContext: nethttp.NewSafeDialer().DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
TLSClientConfig: &tls.Config{RootCAs: p.rootCAs},
},
}

Expand Down Expand Up @@ -259,7 +259,7 @@ func (p *preheat) parseLayers(resp *http.Response, url, tag, filter string, head
return layers, nil
}

func getAuthToken(ctx context.Context, header http.Header) (string, error) {
func getAuthToken(ctx context.Context, header http.Header, rootCAs *x509.CertPool) (string, error) {
ctx, span := tracer.Start(ctx, config.SpanAuthWithRegistry, trace.WithSpanKind(trace.SpanKindProducer))
defer span.End()

Expand All @@ -277,7 +277,7 @@ func getAuthToken(ctx context.Context, header http.Header) (string, error) {
Timeout: defaultHTTPRequesttimeout,
Transport: &http.Transport{
DialContext: nethttp.NewSafeDialer().DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
TLSClientConfig: &tls.Config{RootCAs: rootCAs},
},
}

Expand Down

0 comments on commit 7f67bc5

Please sign in to comment.