Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add NewSafeDialer and fix ssrf in manager preheat api #2611

Merged
merged 1 commit into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions manager/job/preheat.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ func (p *preheat) getManifests(ctx context.Context, url string, header http.Head
client := &http.Client{
Timeout: defaultHTTPRequesttimeout,
Transport: &http.Transport{
DialContext: nethttp.NewSafeDialer().DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
Expand Down Expand Up @@ -275,6 +276,7 @@ func getAuthToken(ctx context.Context, header http.Header) (string, error) {
client := &http.Client{
Timeout: defaultHTTPRequesttimeout,
Transport: &http.Transport{
DialContext: nethttp.NewSafeDialer().DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
Expand Down
41 changes: 41 additions & 0 deletions pkg/net/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
package http

import (
"fmt"
"net"
"net/http"
"syscall"
"time"
)

const (
// DefaultDialTimeout is the default timeout for dialing a http connection.
DefaultDialTimeout = 30 * time.Second
)

// HeaderToMap coverts request headers to map[string]string.
Expand Down Expand Up @@ -48,3 +57,35 @@ func PickHeader(header http.Header, key, defaultValue string) string {

return defaultValue
}

// NewSafeDialer returns a new net.Dialer with safe socket control.
func NewSafeDialer() *net.Dialer {
return &net.Dialer{
Timeout: DefaultDialTimeout,
DualStack: true,
Control: safeSocketControl,
}
}

// safeSocketControl restricts the socket to only connect to valid addresses.
func safeSocketControl(network string, address string, conn syscall.RawConn) error {
if !(network == "tcp4" || network == "tcp6") {
return fmt.Errorf("network type %s is invalid", network)
}

host, _, err := net.SplitHostPort(address)
if err != nil {
return err
}

ip := net.ParseIP(host)
if ip == nil {
return fmt.Errorf("host %s is invalid", host)
}

if !ip.IsGlobalUnicast() || ip.IsPrivate() {
return fmt.Errorf("ip %s is invalid", ip.String())
}

return nil
}