Skip to content

Commit

Permalink
Extract pkg/webhook package from rpc/auth package
Browse files Browse the repository at this point in the history
  • Loading branch information
hackerwins committed Feb 3, 2025
1 parent 3bd66ce commit ff9ff34
Show file tree
Hide file tree
Showing 13 changed files with 291 additions and 231 deletions.
18 changes: 5 additions & 13 deletions cmd/yorkie/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ var (
mongoPingTimeout time.Duration

authWebhookMaxWaitInterval time.Duration
authWebhookCacheAuthTTL time.Duration
authWebhookCacheUnauthTTL time.Duration
authWebhookCacheTTL time.Duration
projectInfoCacheTTL time.Duration

conf = server.NewConfig()
Expand All @@ -65,8 +64,7 @@ func newServerCmd() *cobra.Command {
conf.Backend.ClientDeactivateThreshold = clientDeactivateThreshold

conf.Backend.AuthWebhookMaxWaitInterval = authWebhookMaxWaitInterval.String()
conf.Backend.AuthWebhookCacheAuthTTL = authWebhookCacheAuthTTL.String()
conf.Backend.AuthWebhookCacheUnauthTTL = authWebhookCacheUnauthTTL.String()
conf.Backend.AuthWebhookCacheTTL = authWebhookCacheTTL.String()
conf.Backend.ProjectInfoCacheTTL = projectInfoCacheTTL.String()

conf.Housekeeping.Interval = housekeepingInterval.String()
Expand Down Expand Up @@ -316,16 +314,10 @@ func init() {
"The cache size of the authorization webhook.",
)
cmd.Flags().DurationVar(
&authWebhookCacheAuthTTL,
&authWebhookCacheTTL,
"auth-webhook-cache-auth-ttl",
server.DefaultAuthWebhookCacheAuthTTL,
"TTL value to set when caching authorized webhook response.",
)
cmd.Flags().DurationVar(
&authWebhookCacheUnauthTTL,
"auth-webhook-cache-unauth-ttl",
server.DefaultAuthWebhookCacheUnauthTTL,
"TTL value to set when caching unauthorized webhook response.",
server.DefaultAuthWebhookCacheTTL,
"TTL value to set when caching authorization webhook response.",
)
cmd.Flags().IntVar(
&conf.Backend.ProjectInfoCacheSize,
Expand Down
24 changes: 24 additions & 0 deletions pkg/types/pair.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2025 The Yorkie Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Package types provides the common types used in the Yorkie.
package types

// Pair is a pair of two values.
type Pair[F any, S any] struct {
First F
Second S
}
178 changes: 178 additions & 0 deletions pkg/webhook/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Copyright 2025 The Yorkie Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Package webhook provides a client for the webhook.
package webhook

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"syscall"
"time"

"github.com/yorkie-team/yorkie/pkg/cache"
"github.com/yorkie-team/yorkie/pkg/types"
"github.com/yorkie-team/yorkie/server/logging"
)

var (
// ErrUnexpectedStatusCode is returned when the response code is not 200 from the webhook.
ErrUnexpectedStatusCode = errors.New("unexpected status code from webhook")

// ErrUnexpectedResponse is returned when the response from the webhook is not as expected.
ErrUnexpectedResponse = errors.New("unexpected response from webhook")

// ErrWebhookTimeout is returned when the webhook does not respond in time.
ErrWebhookTimeout = errors.New("webhook timeout")
)

// Options are the options for the webhook client.
type Options struct {
CacheTTL time.Duration

MaxRetries uint64
MaxWaitInterval time.Duration
}

// Client is a client for the webhook.
type Client[Req any, Res any] struct {
url string
cache *cache.LRUExpireCache[string, types.Pair[int, *Res]]
options Options
}

// NewClient creates a new instance of Client.
func NewClient[Req any, Res any](
url string,
Cache *cache.LRUExpireCache[string, types.Pair[int, *Res]],
options Options,
) *Client[Req, Res] {
return &Client[Req, Res]{
url: url,
cache: Cache,
options: options,
}
}

// Send sends the given request to the webhook.
func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, 0, fmt.Errorf("marshal webhook request: %w", err)
}

cacheKey := string(body)
if entry, ok := c.cache.Get(cacheKey); ok {
return entry.Second, entry.First, nil
}

var res Res
status, err := c.withExponentialBackoff(ctx, func() (int, error) {
// TODO(hackerwins, window9u): We should consider using HMAC to sign the request.
resp, err := http.Post(
c.url,
"application/json",
bytes.NewBuffer(body),
)
if err != nil {
return 0, fmt.Errorf("post to webhook: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
logging.From(ctx).Error(err)
}
}()

if resp.StatusCode != http.StatusOK &&
resp.StatusCode != http.StatusUnauthorized &&
resp.StatusCode != http.StatusForbidden {
return resp.StatusCode, ErrUnexpectedStatusCode
}

if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return resp.StatusCode, ErrUnexpectedResponse
}

return resp.StatusCode, nil
})
if err != nil {
return nil, status, err
}

// TODO(hackerwins): We should consider caching the response of Unauthorized as well.
if status != http.StatusUnauthorized {
c.cache.Add(cacheKey, types.Pair[int, *Res]{First: status, Second: &res}, c.options.CacheTTL)
}

return &res, status, nil
}

func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) {
var retries uint64
var statusCode int
for retries <= c.options.MaxRetries {
statusCode, err := webhookFn()
if !shouldRetry(statusCode, err) {
if err == ErrUnexpectedStatusCode {
return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode)
}

return statusCode, err
}

waitBeforeRetry := waitInterval(retries, c.options.MaxWaitInterval)

select {
case <-ctx.Done():
return 0, ctx.Err()
case <-time.After(waitBeforeRetry):
}

retries++
}

return statusCode, fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout)
}

// waitInterval returns the interval of given retries. (2^retries * 100) milliseconds.
func waitInterval(retries uint64, maxWaitInterval time.Duration) time.Duration {
interval := time.Duration(math.Pow(2, float64(retries))) * 100 * time.Millisecond
if maxWaitInterval < interval {
return maxWaitInterval
}

return interval
}

// shouldRetry returns true if the given error should be retried.
// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry
func shouldRetry(statusCode int, err error) bool {
// If the connection is reset, we should retry.
var errno syscall.Errno
if errors.As(err, &errno) {
return errno == syscall.ECONNRESET
}

return statusCode == http.StatusInternalServerError ||
statusCode == http.StatusServiceUnavailable ||
statusCode == http.StatusGatewayTimeout ||
statusCode == http.StatusTooManyRequests
}
21 changes: 13 additions & 8 deletions server/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/yorkie-team/yorkie/api/types"
"github.com/yorkie-team/yorkie/pkg/cache"
pkgtypes "github.com/yorkie-team/yorkie/pkg/types"
"github.com/yorkie-team/yorkie/server/backend/background"
"github.com/yorkie-team/yorkie/server/backend/database"
memdb "github.com/yorkie-team/yorkie/server/backend/database/memory"
Expand All @@ -43,9 +44,12 @@ import (
// Backend manages Yorkie's backend such as Database and Coordinator. And it
// has the server status such as the information of this Server.
type Backend struct {
Config *Config
serverInfo *sync.ServerInfo
AuthWebhookCache *cache.LRUExpireCache[string, *types.AuthWebhookResponse]
Config *Config
serverInfo *sync.ServerInfo
WebhookCache *cache.LRUExpireCache[string, pkgtypes.Pair[
int,
*types.AuthWebhookResponse,
]]

Metrics *prometheus.Metrics
DB database.Database
Expand Down Expand Up @@ -80,8 +84,9 @@ func New(

// 02. Create the auth webhook cache. The auth webhook cache is used to
// cache the response of the auth webhook.
// TODO(hackerwins): Consider to extend the cache for general purpose.
webhookCache, err := cache.NewLRUExpireCache[string, *types.AuthWebhookResponse](conf.AuthWebhookCacheSize)
webhookCache, err := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]](
conf.AuthWebhookCacheSize,
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -145,9 +150,9 @@ func New(
)

return &Backend{
Config: conf,
serverInfo: serverInfo,
AuthWebhookCache: webhookCache,
Config: conf,
serverInfo: serverInfo,
WebhookCache: webhookCache,

Metrics: metrics,
DB: db,
Expand Down
40 changes: 9 additions & 31 deletions server/backend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,8 @@ type Config struct {
// AuthWebhookCacheSize is the cache size of the authorization webhook.
AuthWebhookCacheSize int `yaml:"AuthWebhookCacheSize"`

// AuthWebhookCacheAuthTTL is the TTL value to set when caching the authorized result.
AuthWebhookCacheAuthTTL string `yaml:"AuthWebhookCacheAuthTTL"`

// AuthWebhookCacheUnauthTTL is the TTL value to set when caching the unauthorized result.
AuthWebhookCacheUnauthTTL string `yaml:"AuthWebhookCacheUnauthTTL"`
// AuthWebhookCacheTTL is the TTL value to set when caching the authorized result.
AuthWebhookCacheTTL string `yaml:"AuthWebhookCacheTTL"`

// ProjectInfoCacheSize is the cache size of the project info.
ProjectInfoCacheSize int `yaml:"ProjectInfoCacheSize"`
Expand Down Expand Up @@ -104,18 +101,10 @@ func (c *Config) Validate() error {
)
}

if _, err := time.ParseDuration(c.AuthWebhookCacheAuthTTL); err != nil {
if _, err := time.ParseDuration(c.AuthWebhookCacheTTL); err != nil {
return fmt.Errorf(
`invalid argument "%s" for "--auth-webhook-cache-auth-ttl" flag: %w`,
c.AuthWebhookCacheAuthTTL,
err,
)
}

if _, err := time.ParseDuration(c.AuthWebhookCacheUnauthTTL); err != nil {
return fmt.Errorf(
`invalid argument "%s" for "--auth-webhook-cache-unauth-ttl" flag: %w`,
c.AuthWebhookCacheUnauthTTL,
`invalid argument "%s" for "--auth-webhook-cache-ttl" flag: %w`,
c.AuthWebhookCacheTTL,
err,
)
}
Expand Down Expand Up @@ -153,22 +142,11 @@ func (c *Config) ParseAuthWebhookMaxWaitInterval() time.Duration {
return result
}

// ParseAuthWebhookCacheAuthTTL returns TTL for authorized cache.
func (c *Config) ParseAuthWebhookCacheAuthTTL() time.Duration {
result, err := time.ParseDuration(c.AuthWebhookCacheAuthTTL)
if err != nil {
fmt.Fprintln(os.Stderr, "parse auth webhook cache auth ttl: %w", err)
os.Exit(1)
}

return result
}

// ParseAuthWebhookCacheUnauthTTL returns TTL for unauthorized cache.
func (c *Config) ParseAuthWebhookCacheUnauthTTL() time.Duration {
result, err := time.ParseDuration(c.AuthWebhookCacheUnauthTTL)
// ParseAuthWebhookCacheTTL returns TTL for authorized cache.
func (c *Config) ParseAuthWebhookCacheTTL() time.Duration {
result, err := time.ParseDuration(c.AuthWebhookCacheTTL)
if err != nil {
fmt.Fprintln(os.Stderr, "parse auth webhook cache unauth ttl: %w", err)
fmt.Fprintln(os.Stderr, "parse auth webhook cache ttl: %w", err)
os.Exit(1)
}

Expand Down
11 changes: 3 additions & 8 deletions server/backend/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ func TestConfig(t *testing.T) {
validConf := backend.Config{
ClientDeactivateThreshold: "1h",
AuthWebhookMaxWaitInterval: "0ms",
AuthWebhookCacheAuthTTL: "10s",
AuthWebhookCacheUnauthTTL: "10s",
AuthWebhookCacheTTL: "10s",
ProjectInfoCacheTTL: "10m",
}
assert.NoError(t, validConf.Validate())
Expand All @@ -44,15 +43,11 @@ func TestConfig(t *testing.T) {
assert.Error(t, conf2.Validate())

conf3 := validConf
conf3.AuthWebhookCacheAuthTTL = "s"
conf3.AuthWebhookCacheTTL = "s"
assert.Error(t, conf3.Validate())

conf4 := validConf
conf4.AuthWebhookCacheUnauthTTL = "s"
conf4.ProjectInfoCacheTTL = "10 minutes"
assert.Error(t, conf4.Validate())

conf5 := validConf
conf5.ProjectInfoCacheTTL = "10 minutes"
assert.Error(t, conf5.Validate())
})
}
Loading

0 comments on commit ff9ff34

Please sign in to comment.