From cecc1c47a4a5bb7c48f7bea7e7deb4087b601b25 Mon Sep 17 00:00:00 2001 From: Gaius Date: Tue, 8 Aug 2023 18:43:49 +0800 Subject: [PATCH] fix: add NewSafeDialer and fix ssrf in manager preheat api Signed-off-by: Gaius --- manager/job/preheat.go | 2 ++ pkg/net/http/http.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/manager/job/preheat.go b/manager/job/preheat.go index 00cfe015423..3b8bcb1ddbf 100644 --- a/manager/job/preheat.go +++ b/manager/job/preheat.go @@ -220,6 +220,7 @@ func (p *preheat) getManifests(ctx context.Context, url string, header http.Head client := &http.Client{ Timeout: defaultHTTPRequesttimeout, Transport: &http.Transport{ + DialContext: nethttp.NewSafeDialer().DialContext, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } @@ -275,6 +276,7 @@ func getAuthToken(ctx context.Context, header http.Header) (string, error) { client := &http.Client{ Timeout: defaultHTTPRequesttimeout, Transport: &http.Transport{ + DialContext: nethttp.NewSafeDialer().DialContext, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } diff --git a/pkg/net/http/http.go b/pkg/net/http/http.go index 7333906b8d0..4c574539887 100644 --- a/pkg/net/http/http.go +++ b/pkg/net/http/http.go @@ -17,7 +17,16 @@ package http import ( + "fmt" + "net" "net/http" + "syscall" + "time" +) + +const ( + // DefaultDialTimeout is the default timeout for dialing a http connection. + DefaultDialTimeout = 30 * time.Second ) // HeaderToMap coverts request headers to map[string]string. @@ -48,3 +57,35 @@ func PickHeader(header http.Header, key, defaultValue string) string { return defaultValue } + +// NewSafeDialer returns a new net.Dialer with safe socket control. +func NewSafeDialer() *net.Dialer { + return &net.Dialer{ + Timeout: DefaultDialTimeout, + DualStack: true, + Control: safeSocketControl, + } +} + +// safeSocketControl restricts the socket to only connect to valid addresses. +func safeSocketControl(network string, address string, conn syscall.RawConn) error { + if !(network == "tcp4" || network == "tcp6") { + return fmt.Errorf("network type %s is invalid", network) + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return err + } + + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("host %s is invalid", host) + } + + if !ip.IsGlobalUnicast() || ip.IsPrivate() { + return fmt.Errorf("ip %s is invalid", ip.String()) + } + + return nil +}