Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Request Timing Controls for Auth Webhook Client #1142

Merged
merged 18 commits into from
Feb 6, 2025
20 changes: 18 additions & 2 deletions cmd/yorkie/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ var (
mongoPingTimeout time.Duration

authWebhookMaxWaitInterval time.Duration
authWebhookMinWaitInterval time.Duration
authWebhookRequestTimeout time.Duration
authWebhookCacheTTL time.Duration
projectCacheTTL time.Duration

Expand All @@ -64,6 +66,8 @@ func newServerCmd() *cobra.Command {
conf.Backend.ClientDeactivateThreshold = clientDeactivateThreshold

conf.Backend.AuthWebhookMaxWaitInterval = authWebhookMaxWaitInterval.String()
conf.Backend.AuthWebhookMinWaitInterval = authWebhookMinWaitInterval.String()
conf.Backend.AuthWebhookRequestTimeout = authWebhookRequestTimeout.String()
conf.Backend.AuthWebhookCacheTTL = authWebhookCacheTTL.String()
conf.Backend.ProjectCacheTTL = projectCacheTTL.String()

Expand Down Expand Up @@ -295,17 +299,29 @@ func init() {
server.DefaultSnapshotDisableGC,
"Whether to disable garbage collection of snapshots.",
)
cmd.Flags().DurationVar(
&authWebhookRequestTimeout,
"auth-webhook-request-timeout",
server.DefaultAuthWebhookRequestTimeout,
"Timeout for each authorization webhook request.",
)
cmd.Flags().Uint64Var(
&conf.Backend.AuthWebhookMaxRetries,
"auth-webhook-max-retries",
server.DefaultAuthWebhookMaxRetries,
"Maximum number of retries for an authorization webhook.",
"Maximum number of retries for authorization webhook.",
)
cmd.Flags().DurationVar(
&authWebhookMinWaitInterval,
"auth-webhook-min-wait-interval",
server.DefaultAuthWebhookMinWaitInterval,
"Minimum wait interval between retries(exponential backoff).",
)
cmd.Flags().DurationVar(
&authWebhookMaxWaitInterval,
"auth-webhook-max-wait-interval",
server.DefaultAuthWebhookMaxWaitInterval,
"Maximum wait interval for authorization webhook.",
"Maximum wait interval between retries(exponential backoff).",
)
cmd.Flags().IntVar(
&conf.Backend.AuthWebhookCacheSize,
Expand Down
124 changes: 64 additions & 60 deletions pkg/webhook/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ import (
"syscall"
"time"

"github.com/yorkie-team/yorkie/pkg/cache"
"github.com/yorkie-team/yorkie/pkg/types"
"github.com/yorkie-team/yorkie/server/logging"
)

Expand All @@ -47,54 +45,53 @@ var (
ErrWebhookTimeout = errors.New("webhook timeout")
)

// Options are the options for the webhook client.
// Options are the options for the webhook httpClient.
type Options struct {
CacheKeyPrefix string
CacheTTL time.Duration

RequestTimeout time.Duration
MaxRetries uint64
MinWaitInterval time.Duration
MaxWaitInterval time.Duration

HMACKey string
}

// Client is a client for the webhook.
// Client is a httpClient for the webhook.
type Client[Req any, Res any] struct {
cache *cache.LRUExpireCache[string, types.Pair[int, *Res]]
url string
options Options
httpClient *http.Client
options Options
}

// NewClient creates a new instance of Client.
func NewClient[Req any, Res any](
url string,
Cache *cache.LRUExpireCache[string, types.Pair[int, *Res]],
options Options,
) *Client[Req, Res] {
return &Client[Req, Res]{
url: url,
cache: Cache,
httpClient: &http.Client{
Timeout: options.RequestTimeout,
},
options: options,
}
}

// Send sends the given request to the webhook.
func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) {
body, err := json.Marshal(req)
func (c *Client[Req, Res]) Send(
ctx context.Context,
url, hmacKey string,
body []byte,
) (*Res, int, error) {
signature, err := createSignature(body, hmacKey)
if err != nil {
return nil, 0, fmt.Errorf("marshal webhook request: %w", err)
}

cacheKey := c.options.CacheKeyPrefix + ":" + string(body)
if entry, ok := c.cache.Get(cacheKey); ok {
return entry.Second, entry.First, nil
return nil, 0, fmt.Errorf("create signature: %w", err)
}

var res Res
status, err := c.withExponentialBackoff(ctx, func() (int, error) {
resp, err := c.post("application/json", body)
req, err := c.buildRequest(ctx, url, signature, body)
if err != nil {
return 0, fmt.Errorf("build request: %w", err)
}

resp, err := c.httpClient.Do(req)
if err != nil {
return 0, fmt.Errorf("post to webhook: %w", err)
return 0, fmt.Errorf("do request: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
Expand All @@ -103,9 +100,7 @@ func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error)
}
}()

if resp.StatusCode != http.StatusOK &&
resp.StatusCode != http.StatusUnauthorized &&
resp.StatusCode != http.StatusForbidden {
if !isExpectedStatus(resp.StatusCode) {
return resp.StatusCode, ErrUnexpectedStatusCode
}

Expand All @@ -119,55 +114,58 @@ func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error)
return nil, status, err
}

// TODO(hackerwins): We should consider caching the response of Unauthorized as well.
if status != http.StatusUnauthorized {
c.cache.Add(cacheKey, types.Pair[int, *Res]{First: status, Second: &res}, c.options.CacheTTL)
}

return &res, status, nil
}

// post sends an HTTP POST request with HMAC-SHA256 signature headers.
// If key is empty, post sends an HTTP POST without signature.
func (c *Client[Req, Res]) post(contentType string, body []byte) (*http.Response, error) {
req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(body))
// buildRequest creates a new HTTP POST request with the appropriate headers.
func (c *Client[Req, Res]) buildRequest(
ctx context.Context,
url, hmac string,
body []byte,
) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("create HTTP request: %w", err)
return nil, fmt.Errorf("create POST request with context: %w", err)
}

req.Header.Set("Content-Type", contentType)
if c.options.HMACKey != "" {
mac := hmac.New(sha256.New, []byte(c.options.HMACKey))
if _, err := mac.Write(body); err != nil {
return nil, fmt.Errorf("write HMAC body: %w", err)
}
signature := mac.Sum(nil)
signatureHex := hex.EncodeToString(signature) // Convert to hex string
req.Header.Set("X-Signature-256", fmt.Sprintf("sha256=%s", signatureHex))
}
req.Header.Set("Content-Type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send to %s: %w", c.url, err) // Wrapped with context
if hmac != "" {
req.Header.Set("X-Signature-256", hmac)
}

return resp, nil
return req, nil
}

// createSignature sets the HMAC signature header for the request.
func createSignature(data []byte, hmacKey string) (string, error) {
if hmacKey == "" {
return "", nil
}
mac := hmac.New(sha256.New, []byte(hmacKey))
if _, err := mac.Write(data); err != nil {
return "", fmt.Errorf("write HMAC body: %w", err)
}
signatureHex := hex.EncodeToString(mac.Sum(nil))
return fmt.Sprintf("sha256=%s", signatureHex), nil
}

func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) {
var retries uint64
var statusCode int
var err error

for retries <= c.options.MaxRetries {
statusCode, err := webhookFn()
statusCode, err = webhookFn()
if !shouldRetry(statusCode, err) {
if err == ErrUnexpectedStatusCode {
if errors.Is(err, ErrUnexpectedStatusCode) {
return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode)
}

return statusCode, err
}

waitBeforeRetry := waitInterval(retries, c.options.MaxWaitInterval)
waitBeforeRetry := waitInterval(retries, c.options.MinWaitInterval, c.options.MaxWaitInterval)

select {
case <-ctx.Done():
Expand All @@ -181,9 +179,9 @@ func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn
return statusCode, fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout)
}

// waitInterval returns the interval of given retries. (2^retries * 100) milliseconds.
func waitInterval(retries uint64, maxWaitInterval time.Duration) time.Duration {
interval := time.Duration(math.Pow(2, float64(retries))) * 100 * time.Millisecond
// waitInterval returns the interval of given retries. (2^retries * minWaitInterval) .
func waitInterval(retries uint64, minWaitInterval, maxWaitInterval time.Duration) time.Duration {
interval := time.Duration(math.Pow(2, float64(retries))) * minWaitInterval
if maxWaitInterval < interval {
return maxWaitInterval
}
Expand All @@ -197,11 +195,17 @@ func shouldRetry(statusCode int, err error) bool {
// If the connection is reset, we should retry.
var errno syscall.Errno
if errors.As(err, &errno) {
return errno == syscall.ECONNRESET
return errors.Is(errno, syscall.ECONNRESET)
}

return statusCode == http.StatusInternalServerError ||
statusCode == http.StatusServiceUnavailable ||
statusCode == http.StatusGatewayTimeout ||
statusCode == http.StatusTooManyRequests
}

func isExpectedStatus(statusCode int) bool {
return statusCode == http.StatusOK ||
statusCode == http.StatusUnauthorized ||
statusCode == http.StatusForbidden
}
Loading
Loading