diff --git a/backend/etcd.go b/backend/etcd.go index efeec88..a77d071 100644 --- a/backend/etcd.go +++ b/backend/etcd.go @@ -20,13 +20,18 @@ type EtcdBackend struct { // NewEtcdBackend ... func NewEtcdBackend(endpoints []string) (*EtcdBackend, error) { + defaultDialTimeout := 5 * time.Second + + dialTimeout := getEnvAsDuration("ETCD_DIAL_TIMEOUT", defaultDialTimeout) + cli, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, - DialTimeout: 5 * time.Second, + DialTimeout: dialTimeout, }) if err != nil { return nil, err } + return &EtcdBackend{ client: cli, }, nil diff --git a/backend/http.go b/backend/http.go index 70a23ec..2dc81e7 100644 --- a/backend/http.go +++ b/backend/http.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/json" "fmt" + "log" "net" "net/http" + "os" "time" "wirey/pkg/utils" @@ -29,15 +31,24 @@ type HTTPBackend struct { // NewHTTPBackend ... func NewHTTPBackend(baseurl, wireyVersion string) (*HTTPBackend, error) { + defaultDialTimeout := 30 * time.Second + defaultTLSHandshakeTimeout := 30 * time.Second + defaultClientTimeout := 60 * time.Second + + dialTimeout := getEnvAsDuration("DIAL_TIMEOUT", defaultDialTimeout) + tlsHandshakeTimeout := getEnvAsDuration("TLS_HANDSHAKE_TIMEOUT", defaultTLSHandshakeTimeout) + clientTimeout := getEnvAsDuration("CLIENT_TIMEOUT", defaultClientTimeout) + var transportWithTimeout = &http.Transport{ Dial: (&net.Dialer{ - Timeout: 5 * time.Second, + Timeout: dialTimeout, }).Dial, - TLSHandshakeTimeout: 5 * time.Second, + TLSHandshakeTimeout: tlsHandshakeTimeout, } + return &HTTPBackend{ client: &http.Client{ - Timeout: time.Second * 10, + Timeout: clientTimeout, Transport: transportWithTimeout, }, baseurl: baseurl, @@ -45,6 +56,21 @@ func NewHTTPBackend(baseurl, wireyVersion string) (*HTTPBackend, error) { }, nil } +func getEnvAsDuration(envVar string, defaultValue time.Duration) time.Duration { + valueStr := os.Getenv(envVar) + if valueStr == "" { + return defaultValue + } + + parsedValue, err := time.ParseDuration(valueStr) + if err != nil { + log.Printf("Invalid duration format for %s: %s, using default value %v\n", envVar, valueStr, defaultValue) + return defaultValue + } + + return parsedValue +} + // Join ... func (b *HTTPBackend) Join(ifname string, p Peer) error { joinURL := fmt.Sprintf("%s/%s/%s", b.baseurl, ifname, utils.PublicKeySHA256(p.PublicKey)) diff --git a/backend/plumber.go b/backend/plumber.go index b528df6..17c328e 100644 --- a/backend/plumber.go +++ b/backend/plumber.go @@ -16,9 +16,10 @@ import ( log "github.com/sirupsen/logrus" + "wirey/pkg/wireguard" + "github.com/cenkalti/backoff/v4" "github.com/vishvananda/netlink" - "wirey/pkg/wireguard" ) const ( @@ -36,13 +37,40 @@ const ( errIntConversionPort = "error during port conversion to int: %s" ) +var ( + defaultMaxElapsedTime = 30 * time.Minute + defaultMaxInterval = 120 * time.Second + defaultJitterRange = 5 +) + // values used for exponentialBackoff -const ( - MaxElapsedTime = 15 * time.Minute - MaxInterval = 120 * time.Second - JitterRange = 5 +var ( + MaxElapsedTime time.Duration + MaxInterval time.Duration + JitterRange int ) +func init() { + MaxElapsedTime = getEnvAsDuration("MAX_ELAPSED_TIME", defaultMaxElapsedTime) + MaxInterval = getEnvAsDuration("MAX_INTERVAL", defaultMaxInterval) + JitterRange = getEnvAsInt("JITTER_RANGE", defaultJitterRange) +} + +func getEnvAsInt(envVar string, defaultValue int) int { + valueStr := os.Getenv(envVar) + if valueStr == "" { + return defaultValue + } + + parsedValue, err := strconv.Atoi(valueStr) + if err != nil { + log.Printf("Invalid integer format for %s: %s, using default value %d\n", envVar, valueStr, defaultValue) + return defaultValue + } + + return parsedValue +} + // Peer ... type Peer struct { PublicKey []byte