diff --git a/cmd/yorkie/server.go b/cmd/yorkie/server.go index 54ea63094..f8673fd60 100644 --- a/cmd/yorkie/server.go +++ b/cmd/yorkie/server.go @@ -48,8 +48,7 @@ var ( mongoPingTimeout time.Duration authWebhookMaxWaitInterval time.Duration - authWebhookCacheAuthTTL time.Duration - authWebhookCacheUnauthTTL time.Duration + authWebhookCacheTTL time.Duration projectInfoCacheTTL time.Duration conf = server.NewConfig() @@ -65,8 +64,7 @@ func newServerCmd() *cobra.Command { conf.Backend.ClientDeactivateThreshold = clientDeactivateThreshold conf.Backend.AuthWebhookMaxWaitInterval = authWebhookMaxWaitInterval.String() - conf.Backend.AuthWebhookCacheAuthTTL = authWebhookCacheAuthTTL.String() - conf.Backend.AuthWebhookCacheUnauthTTL = authWebhookCacheUnauthTTL.String() + conf.Backend.AuthWebhookCacheTTL = authWebhookCacheTTL.String() conf.Backend.ProjectInfoCacheTTL = projectInfoCacheTTL.String() conf.Housekeeping.Interval = housekeepingInterval.String() @@ -316,16 +314,10 @@ func init() { "The cache size of the authorization webhook.", ) cmd.Flags().DurationVar( - &authWebhookCacheAuthTTL, + &authWebhookCacheTTL, "auth-webhook-cache-auth-ttl", - server.DefaultAuthWebhookCacheAuthTTL, - "TTL value to set when caching authorized webhook response.", - ) - cmd.Flags().DurationVar( - &authWebhookCacheUnauthTTL, - "auth-webhook-cache-unauth-ttl", - server.DefaultAuthWebhookCacheUnauthTTL, - "TTL value to set when caching unauthorized webhook response.", + server.DefaultAuthWebhookCacheTTL, + "TTL value to set when caching authorization webhook response.", ) cmd.Flags().IntVar( &conf.Backend.ProjectInfoCacheSize, diff --git a/pkg/types/pair.go b/pkg/types/pair.go new file mode 100644 index 000000000..e7e065c2e --- /dev/null +++ b/pkg/types/pair.go @@ -0,0 +1,24 @@ +/* + * Copyright 2025 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package types provides the common types used in the Yorkie. +package types + +// Pair is a pair of two values. +type Pair[F any, S any] struct { + First F + Second S +} diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go new file mode 100644 index 000000000..d717ffd17 --- /dev/null +++ b/pkg/webhook/client.go @@ -0,0 +1,179 @@ +/* + * Copyright 2025 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package webhook provides a client for the webhook. +package webhook + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "math" + "net/http" + "syscall" + "time" + + "github.com/yorkie-team/yorkie/pkg/cache" + "github.com/yorkie-team/yorkie/pkg/types" + "github.com/yorkie-team/yorkie/server/logging" +) + +var ( + // ErrUnexpectedStatusCode is returned when the response code is not 200 from the webhook. + ErrUnexpectedStatusCode = errors.New("unexpected status code from webhook") + + // ErrUnexpectedResponse is returned when the response from the webhook is not as expected. + ErrUnexpectedResponse = errors.New("unexpected response from webhook") + + // ErrWebhookTimeout is returned when the webhook does not respond in time. + ErrWebhookTimeout = errors.New("webhook timeout") +) + +// Options are the options for the webhook client. +type Options struct { + CacheKeyPrefix string + CacheTTL time.Duration + + MaxRetries uint64 + MaxWaitInterval time.Duration +} + +// Client is a client for the webhook. +type Client[Req any, Res any] struct { + cache *cache.LRUExpireCache[string, types.Pair[int, *Res]] + url string + options Options +} + +// NewClient creates a new instance of Client. +func NewClient[Req any, Res any]( + url string, + Cache *cache.LRUExpireCache[string, types.Pair[int, *Res]], + options Options, +) *Client[Req, Res] { + return &Client[Req, Res]{ + url: url, + cache: Cache, + options: options, + } +} + +// Send sends the given request to the webhook. +func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, 0, fmt.Errorf("marshal webhook request: %w", err) + } + + cacheKey := c.options.CacheKeyPrefix + ":" + string(body) + if entry, ok := c.cache.Get(cacheKey); ok { + return entry.Second, entry.First, nil + } + + var res Res + status, err := c.withExponentialBackoff(ctx, func() (int, error) { + // TODO(hackerwins, window9u): We should consider using HMAC to sign the request. + resp, err := http.Post( + c.url, + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + return 0, fmt.Errorf("post to webhook: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + logging.From(ctx).Error(err) + } + }() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusUnauthorized && + resp.StatusCode != http.StatusForbidden { + return resp.StatusCode, ErrUnexpectedStatusCode + } + + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return resp.StatusCode, ErrUnexpectedResponse + } + + return resp.StatusCode, nil + }) + if err != nil { + return nil, status, err + } + + // TODO(hackerwins): We should consider caching the response of Unauthorized as well. + if status != http.StatusUnauthorized { + c.cache.Add(cacheKey, types.Pair[int, *Res]{First: status, Second: &res}, c.options.CacheTTL) + } + + return &res, status, nil +} + +func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) { + var retries uint64 + var statusCode int + for retries <= c.options.MaxRetries { + statusCode, err := webhookFn() + if !shouldRetry(statusCode, err) { + if err == ErrUnexpectedStatusCode { + return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) + } + + return statusCode, err + } + + waitBeforeRetry := waitInterval(retries, c.options.MaxWaitInterval) + + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-time.After(waitBeforeRetry): + } + + retries++ + } + + return statusCode, fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout) +} + +// waitInterval returns the interval of given retries. (2^retries * 100) milliseconds. +func waitInterval(retries uint64, maxWaitInterval time.Duration) time.Duration { + interval := time.Duration(math.Pow(2, float64(retries))) * 100 * time.Millisecond + if maxWaitInterval < interval { + return maxWaitInterval + } + + return interval +} + +// shouldRetry returns true if the given error should be retried. +// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry +func shouldRetry(statusCode int, err error) bool { + // If the connection is reset, we should retry. + var errno syscall.Errno + if errors.As(err, &errno) { + return errno == syscall.ECONNRESET + } + + return statusCode == http.StatusInternalServerError || + statusCode == http.StatusServiceUnavailable || + statusCode == http.StatusGatewayTimeout || + statusCode == http.StatusTooManyRequests +} diff --git a/server/backend/backend.go b/server/backend/backend.go index de4052829..4bba83fdd 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -29,6 +29,7 @@ import ( "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/pkg/cache" + pkgtypes "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/server/backend/background" "github.com/yorkie-team/yorkie/server/backend/database" memdb "github.com/yorkie-team/yorkie/server/backend/database/memory" @@ -43,9 +44,12 @@ import ( // Backend manages Yorkie's backend such as Database and Coordinator. And it // has the server status such as the information of this Server. type Backend struct { - Config *Config - serverInfo *sync.ServerInfo - AuthWebhookCache *cache.LRUExpireCache[string, *types.AuthWebhookResponse] + Config *Config + serverInfo *sync.ServerInfo + WebhookCache *cache.LRUExpireCache[string, pkgtypes.Pair[ + int, + *types.AuthWebhookResponse, + ]] Metrics *prometheus.Metrics DB database.Database @@ -80,8 +84,9 @@ func New( // 02. Create the auth webhook cache. The auth webhook cache is used to // cache the response of the auth webhook. - // TODO(hackerwins): Consider to extend the cache for general purpose. - webhookCache, err := cache.NewLRUExpireCache[string, *types.AuthWebhookResponse](conf.AuthWebhookCacheSize) + webhookCache, err := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( + conf.AuthWebhookCacheSize, + ) if err != nil { return nil, err } @@ -145,9 +150,9 @@ func New( ) return &Backend{ - Config: conf, - serverInfo: serverInfo, - AuthWebhookCache: webhookCache, + Config: conf, + serverInfo: serverInfo, + WebhookCache: webhookCache, Metrics: metrics, DB: db, diff --git a/server/backend/config.go b/server/backend/config.go index 8b5cbc511..a7e194b86 100644 --- a/server/backend/config.go +++ b/server/backend/config.go @@ -67,11 +67,8 @@ type Config struct { // AuthWebhookCacheSize is the cache size of the authorization webhook. AuthWebhookCacheSize int `yaml:"AuthWebhookCacheSize"` - // AuthWebhookCacheAuthTTL is the TTL value to set when caching the authorized result. - AuthWebhookCacheAuthTTL string `yaml:"AuthWebhookCacheAuthTTL"` - - // AuthWebhookCacheUnauthTTL is the TTL value to set when caching the unauthorized result. - AuthWebhookCacheUnauthTTL string `yaml:"AuthWebhookCacheUnauthTTL"` + // AuthWebhookCacheTTL is the TTL value to set when caching the authorized result. + AuthWebhookCacheTTL string `yaml:"AuthWebhookCacheTTL"` // ProjectInfoCacheSize is the cache size of the project info. ProjectInfoCacheSize int `yaml:"ProjectInfoCacheSize"` @@ -104,18 +101,10 @@ func (c *Config) Validate() error { ) } - if _, err := time.ParseDuration(c.AuthWebhookCacheAuthTTL); err != nil { + if _, err := time.ParseDuration(c.AuthWebhookCacheTTL); err != nil { return fmt.Errorf( - `invalid argument "%s" for "--auth-webhook-cache-auth-ttl" flag: %w`, - c.AuthWebhookCacheAuthTTL, - err, - ) - } - - if _, err := time.ParseDuration(c.AuthWebhookCacheUnauthTTL); err != nil { - return fmt.Errorf( - `invalid argument "%s" for "--auth-webhook-cache-unauth-ttl" flag: %w`, - c.AuthWebhookCacheUnauthTTL, + `invalid argument "%s" for "--auth-webhook-cache-ttl" flag: %w`, + c.AuthWebhookCacheTTL, err, ) } @@ -153,22 +142,11 @@ func (c *Config) ParseAuthWebhookMaxWaitInterval() time.Duration { return result } -// ParseAuthWebhookCacheAuthTTL returns TTL for authorized cache. -func (c *Config) ParseAuthWebhookCacheAuthTTL() time.Duration { - result, err := time.ParseDuration(c.AuthWebhookCacheAuthTTL) - if err != nil { - fmt.Fprintln(os.Stderr, "parse auth webhook cache auth ttl: %w", err) - os.Exit(1) - } - - return result -} - -// ParseAuthWebhookCacheUnauthTTL returns TTL for unauthorized cache. -func (c *Config) ParseAuthWebhookCacheUnauthTTL() time.Duration { - result, err := time.ParseDuration(c.AuthWebhookCacheUnauthTTL) +// ParseAuthWebhookCacheTTL returns TTL for authorized cache. +func (c *Config) ParseAuthWebhookCacheTTL() time.Duration { + result, err := time.ParseDuration(c.AuthWebhookCacheTTL) if err != nil { - fmt.Fprintln(os.Stderr, "parse auth webhook cache unauth ttl: %w", err) + fmt.Fprintln(os.Stderr, "parse auth webhook cache ttl: %w", err) os.Exit(1) } diff --git a/server/backend/config_test.go b/server/backend/config_test.go index 0a334212c..84e83a2f0 100644 --- a/server/backend/config_test.go +++ b/server/backend/config_test.go @@ -29,8 +29,7 @@ func TestConfig(t *testing.T) { validConf := backend.Config{ ClientDeactivateThreshold: "1h", AuthWebhookMaxWaitInterval: "0ms", - AuthWebhookCacheAuthTTL: "10s", - AuthWebhookCacheUnauthTTL: "10s", + AuthWebhookCacheTTL: "10s", ProjectInfoCacheTTL: "10m", } assert.NoError(t, validConf.Validate()) @@ -44,15 +43,11 @@ func TestConfig(t *testing.T) { assert.Error(t, conf2.Validate()) conf3 := validConf - conf3.AuthWebhookCacheAuthTTL = "s" + conf3.AuthWebhookCacheTTL = "s" assert.Error(t, conf3.Validate()) conf4 := validConf - conf4.AuthWebhookCacheUnauthTTL = "s" + conf4.ProjectInfoCacheTTL = "10 minutes" assert.Error(t, conf4.Validate()) - - conf5 := validConf - conf5.ProjectInfoCacheTTL = "10 minutes" - assert.Error(t, conf5.Validate()) }) } diff --git a/server/config.go b/server/config.go index 425a2d947..6067729de 100644 --- a/server/config.go +++ b/server/config.go @@ -63,8 +63,7 @@ const ( DefaultAuthWebhookMaxRetries = 10 DefaultAuthWebhookMaxWaitInterval = 3000 * time.Millisecond DefaultAuthWebhookCacheSize = 5000 - DefaultAuthWebhookCacheAuthTTL = 10 * time.Second - DefaultAuthWebhookCacheUnauthTTL = 10 * time.Second + DefaultAuthWebhookCacheTTL = 10 * time.Second DefaultProjectInfoCacheSize = 256 DefaultProjectInfoCacheTTL = 10 * time.Minute @@ -186,12 +185,8 @@ func (c *Config) ensureDefaultValue() { c.Backend.AuthWebhookMaxWaitInterval = DefaultAuthWebhookMaxWaitInterval.String() } - if c.Backend.AuthWebhookCacheAuthTTL == "" { - c.Backend.AuthWebhookCacheAuthTTL = DefaultAuthWebhookCacheAuthTTL.String() - } - - if c.Backend.AuthWebhookCacheUnauthTTL == "" { - c.Backend.AuthWebhookCacheUnauthTTL = DefaultAuthWebhookCacheUnauthTTL.String() + if c.Backend.AuthWebhookCacheTTL == "" { + c.Backend.AuthWebhookCacheTTL = DefaultAuthWebhookCacheTTL.String() } if c.Backend.ProjectInfoCacheSize == 0 { diff --git a/server/config.sample.yml b/server/config.sample.yml index 8f8bb81e1..124e6ac2d 100644 --- a/server/config.sample.yml +++ b/server/config.sample.yml @@ -45,7 +45,7 @@ Backend: # If public key is not provided from the client, the default project will be # used. If we are using server as single-tenant mode, this should be set to true. UseDefaultProject: true - + # ClientDeactivateThreshold is deactivate threshold of clients in specific project for housekeeping. ClientDeactivateThreshold: "24h" @@ -63,7 +63,7 @@ Backend: AuthWebhookURL: "" # AuthWebhookMethods is the list of methods to use for authorization. - AuthWebhookMethods: [ ] + AuthWebhookMethods: [] # AuthWebhookMaxRetries is the max count that retries the authorization webhook. AuthWebhookMaxRetries: 10 @@ -71,11 +71,8 @@ Backend: # AuthWebhookMaxWaitInterval is the max interval that waits before retrying the authorization webhook. AuthWebhookMaxWaitInterval: "3s" - # AuthWebhookCacheAuthTTL is the TTL value to set when caching the authorized result. - AuthWebhookCacheAuthTTL: "10s" - - # AuthWebhookCacheUnauthTTL is the TTL value to set when caching the unauthorized result. - AuthWebhookCacheUnauthTTL: "10s" + # AuthWebhookCacheTTL is the TTL value to set when caching the authorized result. + AuthWebhookCacheTTL: "10s" # ProjectInfoCacheSize is the size of the project info cache. ProjectInfoCacheSize: 256 diff --git a/server/config_test.go b/server/config_test.go index 43d2103a7..ed5604aeb 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -70,13 +70,9 @@ func TestNewConfigFromFile(t *testing.T) { assert.NoError(t, err) assert.Equal(t, authWebhookMaxWaitInterval, server.DefaultAuthWebhookMaxWaitInterval) - authWebhookCacheAuthTTL, err := time.ParseDuration(conf.Backend.AuthWebhookCacheAuthTTL) + authWebhookCacheTTL, err := time.ParseDuration(conf.Backend.AuthWebhookCacheTTL) assert.NoError(t, err) - assert.Equal(t, authWebhookCacheAuthTTL, server.DefaultAuthWebhookCacheAuthTTL) - - authWebhookCacheUnauthTTL, err := time.ParseDuration(conf.Backend.AuthWebhookCacheUnauthTTL) - assert.NoError(t, err) - assert.Equal(t, authWebhookCacheUnauthTTL, server.DefaultAuthWebhookCacheUnauthTTL) + assert.Equal(t, authWebhookCacheTTL, server.DefaultAuthWebhookCacheTTL) projectInfoCacheTTL, err := time.ParseDuration(conf.Backend.ProjectInfoCacheTTL) assert.NoError(t, err) diff --git a/server/rpc/auth/auth.go b/server/rpc/auth/auth.go index e6f86c7cb..c56ebb0c5 100644 --- a/server/rpc/auth/auth.go +++ b/server/rpc/auth/auth.go @@ -45,16 +45,16 @@ func AccessAttributes(pack *change.Pack) []types.AccessAttribute { // VerifyAccess verifies the given access. func VerifyAccess(ctx context.Context, be *backend.Backend, accessInfo *types.AccessInfo) error { md := metadata.From(ctx) - project := projects.From(ctx) + prj := projects.From(ctx) - if !project.RequireAuth(accessInfo.Method) { + if !prj.RequireAuth(accessInfo.Method) { return nil } return verifyAccess( ctx, be, - project.AuthWebhookURL, + prj, md.Authorization, accessInfo, ) diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index ec997c316..218390f78 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -17,20 +17,15 @@ package auth import ( - "bytes" "context" - "encoding/json" "errors" "fmt" - "math" "net/http" - "syscall" - "time" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/internal/metaerrors" + "github.com/yorkie-team/yorkie/pkg/webhook" "github.com/yorkie-team/yorkie/server/backend" - "github.com/yorkie-team/yorkie/server/logging" ) var ( @@ -39,146 +34,51 @@ var ( // ErrPermissionDenied is returned when the given user is not allowed for the access. ErrPermissionDenied = errors.New("method is not allowed for this user") - - // ErrUnexpectedStatusCode is returned when the response code is not 200 from the webhook. - ErrUnexpectedStatusCode = errors.New("unexpected status code from webhook") - - // ErrUnexpectedResponse is returned when the response from the webhook is not as expected. - ErrUnexpectedResponse = errors.New("unexpected response from webhook") - - // ErrWebhookTimeout is returned when the webhook does not respond in time. - ErrWebhookTimeout = errors.New("webhook timeout") ) // verifyAccess verifies the given user is allowed to access the given method. func verifyAccess( ctx context.Context, be *backend.Backend, - authWebhookURL string, + prj *types.Project, token string, accessInfo *types.AccessInfo, ) error { - reqBody, err := json.Marshal(types.AuthWebhookRequest{ + cli := webhook.NewClient[types.AuthWebhookRequest]( + prj.AuthWebhookURL, + be.WebhookCache, + webhook.Options{ + CacheKeyPrefix: prj.PublicKey + ":auth", + CacheTTL: be.Config.ParseAuthWebhookCacheTTL(), + MaxRetries: be.Config.AuthWebhookMaxRetries, + MaxWaitInterval: be.Config.ParseAuthWebhookMaxWaitInterval(), + }, + ) + + res, status, err := cli.Send(ctx, types.AuthWebhookRequest{ Token: token, Method: accessInfo.Method, Attributes: accessInfo.Attributes, }) if err != nil { - return fmt.Errorf("marshal auth webhook request: %w", err) + return fmt.Errorf("send to webhook: %w", err) } - cacheKey := string(reqBody) - if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok { - resp := entry - if !resp.Allowed { - return fmt.Errorf("%s: %w", resp.Reason, ErrPermissionDenied) - } + if status == http.StatusOK && res.Allowed { return nil } - - var authResp *types.AuthWebhookResponse - if err := withExponentialBackoff(ctx, be.Config, func() (int, error) { - resp, err := http.Post( - authWebhookURL, - "application/json", - bytes.NewBuffer(reqBody), + if status == http.StatusForbidden && !res.Allowed { + return metaerrors.New( + ErrPermissionDenied, + map[string]string{"reason": res.Reason}, ) - if err != nil { - return 0, fmt.Errorf("post to webhook: %w", err) - } - - defer func() { - if err := resp.Body.Close(); err != nil { - logging.From(ctx).Error(err) - } - }() - - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusUnauthorized && - resp.StatusCode != http.StatusForbidden { - return resp.StatusCode, ErrUnexpectedStatusCode - } - - authResp, err = types.NewAuthWebhookResponse(resp.Body) - if err != nil { - return resp.StatusCode, err - } - - if resp.StatusCode == http.StatusOK && authResp.Allowed { - return resp.StatusCode, nil - } - if resp.StatusCode == http.StatusForbidden && !authResp.Allowed { - return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Reason, ErrPermissionDenied) - } - if resp.StatusCode == http.StatusUnauthorized && !authResp.Allowed { - return resp.StatusCode, metaerrors.New( - ErrUnauthenticated, - map[string]string{"reason": authResp.Reason}, - ) - } - - return resp.StatusCode, fmt.Errorf("%d: %w", resp.StatusCode, ErrUnexpectedResponse) - }); err != nil { - if errors.Is(err, ErrPermissionDenied) { - be.AuthWebhookCache.Add(cacheKey, authResp, be.Config.ParseAuthWebhookCacheUnauthTTL()) - } - - return err } - - be.AuthWebhookCache.Add(cacheKey, authResp, be.Config.ParseAuthWebhookCacheAuthTTL()) - - return nil -} - -func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn func() (int, error)) error { - var retries uint64 - var statusCode int - for retries <= cfg.AuthWebhookMaxRetries { - statusCode, err := webhookFn() - if !shouldRetry(statusCode, err) { - if err == ErrUnexpectedStatusCode { - return fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) - } - - return err - } - - waitBeforeRetry := waitInterval(retries, cfg.ParseAuthWebhookMaxWaitInterval()) - - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(waitBeforeRetry): - } - - retries++ - } - - return fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout) -} - -// waitInterval returns the interval of given retries. (2^retries * 100) milliseconds. -func waitInterval(retries uint64, maxWaitInterval time.Duration) time.Duration { - interval := time.Duration(math.Pow(2, float64(retries))) * 100 * time.Millisecond - if maxWaitInterval < interval { - return maxWaitInterval - } - - return interval -} - -// shouldRetry returns true if the given error should be retried. -// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry -func shouldRetry(statusCode int, err error) bool { - // If the connection is reset, we should retry. - var errno syscall.Errno - if errors.As(err, &errno) { - return errno == syscall.ECONNRESET + if status == http.StatusUnauthorized && !res.Allowed { + return metaerrors.New( + ErrUnauthenticated, + map[string]string{"reason": res.Reason}, + ) } - return statusCode == http.StatusInternalServerError || - statusCode == http.StatusServiceUnavailable || - statusCode == http.StatusGatewayTimeout || - statusCode == http.StatusTooManyRequests + return fmt.Errorf("%d: %w", status, webhook.ErrUnexpectedResponse) } diff --git a/server/rpc/connecthelper/status.go b/server/rpc/connecthelper/status.go index b56fd161c..37b865546 100644 --- a/server/rpc/connecthelper/status.go +++ b/server/rpc/connecthelper/status.go @@ -30,6 +30,7 @@ import ( "github.com/yorkie-team/yorkie/internal/validation" "github.com/yorkie-team/yorkie/pkg/document/key" "github.com/yorkie-team/yorkie/pkg/document/time" + "github.com/yorkie-team/yorkie/pkg/webhook" "github.com/yorkie-team/yorkie/server/backend/database" "github.com/yorkie-team/yorkie/server/clients" "github.com/yorkie-team/yorkie/server/documents" @@ -83,9 +84,9 @@ var errorToConnectCode = map[error]connect.Code{ database.ErrMismatchedPassword: connect.CodeUnauthenticated, // Internal means an internal error occurred. - auth.ErrUnexpectedStatusCode: connect.CodeInternal, - auth.ErrUnexpectedResponse: connect.CodeInternal, - auth.ErrWebhookTimeout: connect.CodeInternal, + webhook.ErrUnexpectedStatusCode: connect.CodeInternal, + webhook.ErrUnexpectedResponse: connect.CodeInternal, + webhook.ErrWebhookTimeout: connect.CodeInternal, // PermissionDenied means the request does not have permission for the operation. auth.ErrPermissionDenied: connect.CodePermissionDenied, @@ -131,12 +132,12 @@ var errorToCode = map[error]string{ converter.ErrUnsupportedValueType: "ErrUnsupportedValueType", converter.ErrUnsupportedCounterType: "ErrUnsupportedCounterType", - auth.ErrPermissionDenied: "ErrPermissionDenied", - auth.ErrUnauthenticated: "ErrUnauthenticated", - auth.ErrUnexpectedResponse: "ErrUnexpectedResponse", - auth.ErrUnexpectedStatusCode: "ErrUnexpectedStatusCode", - auth.ErrWebhookTimeout: "ErrWebhookTimeout", - database.ErrMismatchedPassword: "ErrMismatchedPassword", + auth.ErrPermissionDenied: "ErrPermissionDenied", + auth.ErrUnauthenticated: "ErrUnauthenticated", + webhook.ErrUnexpectedResponse: "ErrUnexpectedResponse", + webhook.ErrUnexpectedStatusCode: "ErrUnexpectedStatusCode", + webhook.ErrWebhookTimeout: "ErrWebhookTimeout", + database.ErrMismatchedPassword: "ErrMismatchedPassword", } // CodeOf returns the string representation of the given error. diff --git a/test/helper/helper.go b/test/helper/helper.go index 086cfd13e..708bc6248 100644 --- a/test/helper/helper.go +++ b/test/helper/helper.go @@ -76,8 +76,7 @@ var ( SnapshotWithPurgingChanges = false AuthWebhookMaxWaitInterval = 3 * gotime.Millisecond AuthWebhookSize = 100 - AuthWebhookCacheAuthTTL = 10 * gotime.Second - AuthWebhookCacheUnauthTTL = 10 * gotime.Second + AuthWebhookCacheTTL = 10 * gotime.Second ProjectInfoCacheSize = 256 ProjectInfoCacheTTL = 5 * gotime.Second @@ -266,8 +265,7 @@ func TestConfig() *server.Config { SnapshotWithPurgingChanges: SnapshotWithPurgingChanges, AuthWebhookMaxWaitInterval: AuthWebhookMaxWaitInterval.String(), AuthWebhookCacheSize: AuthWebhookSize, - AuthWebhookCacheAuthTTL: AuthWebhookCacheAuthTTL.String(), - AuthWebhookCacheUnauthTTL: AuthWebhookCacheUnauthTTL.String(), + AuthWebhookCacheTTL: AuthWebhookCacheTTL.String(), ProjectInfoCacheSize: ProjectInfoCacheSize, ProjectInfoCacheTTL: ProjectInfoCacheTTL.String(), GatewayAddr: fmt.Sprintf("localhost:%d", RPCPort+portOffset), diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index 05c7ddf28..baf1e0bf9 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -35,6 +35,7 @@ import ( "github.com/yorkie-team/yorkie/pkg/document" "github.com/yorkie-team/yorkie/pkg/document/json" "github.com/yorkie-team/yorkie/pkg/document/presence" + "github.com/yorkie-team/yorkie/pkg/webhook" "github.com/yorkie-team/yorkie/server" "github.com/yorkie-team/yorkie/server/rpc/auth" "github.com/yorkie-team/yorkie/server/rpc/connecthelper" @@ -307,7 +308,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) { defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) - assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedStatusCode), converter.ErrorCodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(webhook.ErrUnexpectedStatusCode), converter.ErrorCodeOf(err)) }) t.Run("unexpected webhook response test", func(t *testing.T) { @@ -348,7 +349,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) { defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) - assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedResponse), converter.ErrorCodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(webhook.ErrUnexpectedResponse), converter.ErrorCodeOf(err)) }) t.Run("unavailable authentication server test(timeout)", func(t *testing.T) { @@ -378,7 +379,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) { err = cli.Activate(ctx) assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) - assert.Equal(t, connecthelper.CodeOf(auth.ErrWebhookTimeout), converter.ErrorCodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(webhook.ErrWebhookTimeout), converter.ErrorCodeOf(err)) }) t.Run("successful authorization after temporarily unavailable server test", func(t *testing.T) { @@ -434,9 +435,9 @@ func TestAuthWebhookCache(t *testing.T) { } })) - authorizedTTL := 1 * time.Second + authTTL := 1 * time.Second conf := helper.TestConfig() - conf.Backend.AuthWebhookCacheAuthTTL = authorizedTTL.String() + conf.Backend.AuthWebhookCacheTTL = authTTL.String() svr, err := server.New(conf) assert.NoError(t, err) @@ -483,7 +484,7 @@ func TestAuthWebhookCache(t *testing.T) { } // 02. multiple requests to update the document after eviction by ttl. - time.Sleep(authorizedTTL) + time.Sleep(authTTL) for i := 0; i < 3; i++ { assert.NoError(t, doc.Update(func(root *json.Object, p *presence.Presence) error { root.SetNewObject("k1") @@ -512,9 +513,9 @@ func TestAuthWebhookCache(t *testing.T) { reqCnt++ })) - unauthorizedTTL := 1 * time.Second + authTTL := 1 * time.Second conf := helper.TestConfig() - conf.Backend.AuthWebhookCacheUnauthTTL = unauthorizedTTL.String() + conf.Backend.AuthWebhookCacheTTL = authTTL.String() svr, err := server.New(conf) assert.NoError(t, err) @@ -551,7 +552,7 @@ func TestAuthWebhookCache(t *testing.T) { } // 02. multiple requests after eviction by ttl. - time.Sleep(unauthorizedTTL) + time.Sleep(authTTL) for i := 0; i < 3; i++ { err = cli.Activate(ctx) assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err)) @@ -576,9 +577,9 @@ func TestAuthWebhookCache(t *testing.T) { reqCnt++ })) - unauthorizedTTL := 1 * time.Second + authTTL := 1 * time.Second conf := helper.TestConfig() - conf.Backend.AuthWebhookCacheUnauthTTL = unauthorizedTTL.String() + conf.Backend.AuthWebhookCacheTTL = authTTL.String() svr, err := server.New(conf) assert.NoError(t, err) @@ -615,7 +616,7 @@ func TestAuthWebhookCache(t *testing.T) { } // 02. multiple requests after eviction by ttl. - time.Sleep(unauthorizedTTL) + time.Sleep(authTTL) for i := 0; i < 3; i++ { err = cli.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))