Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract pkg/webhook package from rpc/auth package #1133

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions cmd/yorkie/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ var (
mongoPingTimeout time.Duration

authWebhookMaxWaitInterval time.Duration
authWebhookCacheAuthTTL time.Duration
authWebhookCacheUnauthTTL time.Duration
authWebhookCacheTTL time.Duration
projectInfoCacheTTL time.Duration

conf = server.NewConfig()
Expand All @@ -65,8 +64,7 @@ func newServerCmd() *cobra.Command {
conf.Backend.ClientDeactivateThreshold = clientDeactivateThreshold

conf.Backend.AuthWebhookMaxWaitInterval = authWebhookMaxWaitInterval.String()
conf.Backend.AuthWebhookCacheAuthTTL = authWebhookCacheAuthTTL.String()
conf.Backend.AuthWebhookCacheUnauthTTL = authWebhookCacheUnauthTTL.String()
conf.Backend.AuthWebhookCacheTTL = authWebhookCacheTTL.String()
conf.Backend.ProjectInfoCacheTTL = projectInfoCacheTTL.String()

conf.Housekeeping.Interval = housekeepingInterval.String()
Expand Down Expand Up @@ -316,16 +314,10 @@ func init() {
"The cache size of the authorization webhook.",
)
cmd.Flags().DurationVar(
&authWebhookCacheAuthTTL,
&authWebhookCacheTTL,
"auth-webhook-cache-auth-ttl",
server.DefaultAuthWebhookCacheAuthTTL,
"TTL value to set when caching authorized webhook response.",
)
cmd.Flags().DurationVar(
&authWebhookCacheUnauthTTL,
"auth-webhook-cache-unauth-ttl",
server.DefaultAuthWebhookCacheUnauthTTL,
"TTL value to set when caching unauthorized webhook response.",
server.DefaultAuthWebhookCacheTTL,
"TTL value to set when caching authorization webhook response.",
)
cmd.Flags().IntVar(
&conf.Backend.ProjectInfoCacheSize,
Expand Down
24 changes: 24 additions & 0 deletions pkg/types/pair.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2025 The Yorkie Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Package types provides the common types used in the Yorkie.
package types

// Pair is a pair of two values.
type Pair[F any, S any] struct {
First F
Second S
}
179 changes: 179 additions & 0 deletions pkg/webhook/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Copyright 2025 The Yorkie Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Package webhook provides a client for the webhook.
package webhook

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"syscall"
"time"

"github.com/yorkie-team/yorkie/pkg/cache"
"github.com/yorkie-team/yorkie/pkg/types"
"github.com/yorkie-team/yorkie/server/logging"
)

var (
// ErrUnexpectedStatusCode is returned when the response code is not 200 from the webhook.
ErrUnexpectedStatusCode = errors.New("unexpected status code from webhook")

// ErrUnexpectedResponse is returned when the response from the webhook is not as expected.
ErrUnexpectedResponse = errors.New("unexpected response from webhook")

// ErrWebhookTimeout is returned when the webhook does not respond in time.
ErrWebhookTimeout = errors.New("webhook timeout")
)

// Options are the options for the webhook client.
type Options struct {
CacheKeyPrefix string
CacheTTL time.Duration

MaxRetries uint64
MaxWaitInterval time.Duration
}

// Client is a client for the webhook.
type Client[Req any, Res any] struct {
cache *cache.LRUExpireCache[string, types.Pair[int, *Res]]
url string
options Options
}

// NewClient creates a new instance of Client.
func NewClient[Req any, Res any](
url string,
Cache *cache.LRUExpireCache[string, types.Pair[int, *Res]],
options Options,
) *Client[Req, Res] {
return &Client[Req, Res]{
url: url,
cache: Cache,
options: options,
}
}

// Send sends the given request to the webhook.
func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, 0, fmt.Errorf("marshal webhook request: %w", err)
}

cacheKey := c.options.CacheKeyPrefix + ":" + string(body)
if entry, ok := c.cache.Get(cacheKey); ok {
return entry.Second, entry.First, nil
}

var res Res
status, err := c.withExponentialBackoff(ctx, func() (int, error) {
// TODO(hackerwins, window9u): We should consider using HMAC to sign the request.
resp, err := http.Post(
c.url,
"application/json",
bytes.NewBuffer(body),
)
if err != nil {
return 0, fmt.Errorf("post to webhook: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
logging.From(ctx).Error(err)
}
}()

if resp.StatusCode != http.StatusOK &&
resp.StatusCode != http.StatusUnauthorized &&
resp.StatusCode != http.StatusForbidden {
return resp.StatusCode, ErrUnexpectedStatusCode
}

if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return resp.StatusCode, ErrUnexpectedResponse
}

return resp.StatusCode, nil
})
if err != nil {
return nil, status, err
}

// TODO(hackerwins): We should consider caching the response of Unauthorized as well.
if status != http.StatusUnauthorized {
c.cache.Add(cacheKey, types.Pair[int, *Res]{First: status, Second: &res}, c.options.CacheTTL)
}

return &res, status, nil
}

func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) {
var retries uint64
var statusCode int
for retries <= c.options.MaxRetries {
statusCode, err := webhookFn()
if !shouldRetry(statusCode, err) {
if err == ErrUnexpectedStatusCode {
return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode)
}

return statusCode, err
}

waitBeforeRetry := waitInterval(retries, c.options.MaxWaitInterval)

select {
case <-ctx.Done():
return 0, ctx.Err()
case <-time.After(waitBeforeRetry):
}

retries++
}

return statusCode, fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout)
}

// waitInterval returns the interval of given retries. (2^retries * 100) milliseconds.
func waitInterval(retries uint64, maxWaitInterval time.Duration) time.Duration {
interval := time.Duration(math.Pow(2, float64(retries))) * 100 * time.Millisecond
if maxWaitInterval < interval {
return maxWaitInterval
}

return interval
}

// shouldRetry returns true if the given error should be retried.
// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry
func shouldRetry(statusCode int, err error) bool {
// If the connection is reset, we should retry.
var errno syscall.Errno
if errors.As(err, &errno) {
return errno == syscall.ECONNRESET
}

return statusCode == http.StatusInternalServerError ||
statusCode == http.StatusServiceUnavailable ||
statusCode == http.StatusGatewayTimeout ||
statusCode == http.StatusTooManyRequests
}
21 changes: 13 additions & 8 deletions server/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/yorkie-team/yorkie/api/types"
"github.com/yorkie-team/yorkie/pkg/cache"
pkgtypes "github.com/yorkie-team/yorkie/pkg/types"
"github.com/yorkie-team/yorkie/server/backend/background"
"github.com/yorkie-team/yorkie/server/backend/database"
memdb "github.com/yorkie-team/yorkie/server/backend/database/memory"
Expand All @@ -43,9 +44,12 @@ import (
// Backend manages Yorkie's backend such as Database and Coordinator. And it
// has the server status such as the information of this Server.
type Backend struct {
Config *Config
serverInfo *sync.ServerInfo
AuthWebhookCache *cache.LRUExpireCache[string, *types.AuthWebhookResponse]
Config *Config
serverInfo *sync.ServerInfo
WebhookCache *cache.LRUExpireCache[string, pkgtypes.Pair[
int,
*types.AuthWebhookResponse,
]]

Metrics *prometheus.Metrics
DB database.Database
Expand Down Expand Up @@ -80,8 +84,9 @@ func New(

// 02. Create the auth webhook cache. The auth webhook cache is used to
// cache the response of the auth webhook.
// TODO(hackerwins): Consider to extend the cache for general purpose.
webhookCache, err := cache.NewLRUExpireCache[string, *types.AuthWebhookResponse](conf.AuthWebhookCacheSize)
webhookCache, err := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]](
conf.AuthWebhookCacheSize,
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -145,9 +150,9 @@ func New(
)

return &Backend{
Config: conf,
serverInfo: serverInfo,
AuthWebhookCache: webhookCache,
Config: conf,
serverInfo: serverInfo,
WebhookCache: webhookCache,

Metrics: metrics,
DB: db,
Expand Down
40 changes: 9 additions & 31 deletions server/backend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,8 @@ type Config struct {
// AuthWebhookCacheSize is the cache size of the authorization webhook.
AuthWebhookCacheSize int `yaml:"AuthWebhookCacheSize"`

// AuthWebhookCacheAuthTTL is the TTL value to set when caching the authorized result.
AuthWebhookCacheAuthTTL string `yaml:"AuthWebhookCacheAuthTTL"`

// AuthWebhookCacheUnauthTTL is the TTL value to set when caching the unauthorized result.
AuthWebhookCacheUnauthTTL string `yaml:"AuthWebhookCacheUnauthTTL"`
// AuthWebhookCacheTTL is the TTL value to set when caching the authorized result.
AuthWebhookCacheTTL string `yaml:"AuthWebhookCacheTTL"`

// ProjectInfoCacheSize is the cache size of the project info.
ProjectInfoCacheSize int `yaml:"ProjectInfoCacheSize"`
Expand Down Expand Up @@ -104,18 +101,10 @@ func (c *Config) Validate() error {
)
}

if _, err := time.ParseDuration(c.AuthWebhookCacheAuthTTL); err != nil {
if _, err := time.ParseDuration(c.AuthWebhookCacheTTL); err != nil {
return fmt.Errorf(
`invalid argument "%s" for "--auth-webhook-cache-auth-ttl" flag: %w`,
c.AuthWebhookCacheAuthTTL,
err,
)
}

if _, err := time.ParseDuration(c.AuthWebhookCacheUnauthTTL); err != nil {
return fmt.Errorf(
`invalid argument "%s" for "--auth-webhook-cache-unauth-ttl" flag: %w`,
c.AuthWebhookCacheUnauthTTL,
`invalid argument "%s" for "--auth-webhook-cache-ttl" flag: %w`,
c.AuthWebhookCacheTTL,
err,
)
}
Expand Down Expand Up @@ -153,22 +142,11 @@ func (c *Config) ParseAuthWebhookMaxWaitInterval() time.Duration {
return result
}

// ParseAuthWebhookCacheAuthTTL returns TTL for authorized cache.
func (c *Config) ParseAuthWebhookCacheAuthTTL() time.Duration {
result, err := time.ParseDuration(c.AuthWebhookCacheAuthTTL)
if err != nil {
fmt.Fprintln(os.Stderr, "parse auth webhook cache auth ttl: %w", err)
os.Exit(1)
}

return result
}

// ParseAuthWebhookCacheUnauthTTL returns TTL for unauthorized cache.
func (c *Config) ParseAuthWebhookCacheUnauthTTL() time.Duration {
result, err := time.ParseDuration(c.AuthWebhookCacheUnauthTTL)
// ParseAuthWebhookCacheTTL returns TTL for authorized cache.
func (c *Config) ParseAuthWebhookCacheTTL() time.Duration {
result, err := time.ParseDuration(c.AuthWebhookCacheTTL)
if err != nil {
fmt.Fprintln(os.Stderr, "parse auth webhook cache unauth ttl: %w", err)
fmt.Fprintln(os.Stderr, "parse auth webhook cache ttl: %w", err)
os.Exit(1)
}

Expand Down
11 changes: 3 additions & 8 deletions server/backend/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ func TestConfig(t *testing.T) {
validConf := backend.Config{
ClientDeactivateThreshold: "1h",
AuthWebhookMaxWaitInterval: "0ms",
AuthWebhookCacheAuthTTL: "10s",
AuthWebhookCacheUnauthTTL: "10s",
AuthWebhookCacheTTL: "10s",
ProjectInfoCacheTTL: "10m",
}
assert.NoError(t, validConf.Validate())
Expand All @@ -44,15 +43,11 @@ func TestConfig(t *testing.T) {
assert.Error(t, conf2.Validate())

conf3 := validConf
conf3.AuthWebhookCacheAuthTTL = "s"
conf3.AuthWebhookCacheTTL = "s"
assert.Error(t, conf3.Validate())

conf4 := validConf
conf4.AuthWebhookCacheUnauthTTL = "s"
conf4.ProjectInfoCacheTTL = "10 minutes"
assert.Error(t, conf4.Validate())

conf5 := validConf
conf5.ProjectInfoCacheTTL = "10 minutes"
assert.Error(t, conf5.Validate())
})
}
Loading
Loading