Skip to content

Commit

Permalink
feat: support HTTP 429 with Retry-After (#194)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcin Rataj <[email protected]>
  • Loading branch information
hacdias and lidel authored Mar 6, 2023
1 parent 7af2b2b commit 15f2131
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 104 deletions.
178 changes: 178 additions & 0 deletions gateway/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package gateway

import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"time"

ipld "github.com/ipfs/go-ipld-format"
"github.com/ipfs/go-path/resolver"
)

var (
ErrInternalServerError = NewErrorResponseForCode(http.StatusInternalServerError)
ErrGatewayTimeout = NewErrorResponseForCode(http.StatusGatewayTimeout)
ErrBadGateway = NewErrorResponseForCode(http.StatusBadGateway)
ErrServiceUnavailable = NewErrorResponseForCode(http.StatusServiceUnavailable)
ErrTooManyRequests = NewErrorResponseForCode(http.StatusTooManyRequests)
)

type ErrorRetryAfter struct {
Err error
RetryAfter time.Duration
}

// NewErrorWithRetryAfter wraps any error in RetryAfter hint that
// gets passed to HTTP clients in Retry-After HTTP header.
func NewErrorRetryAfter(err error, retryAfter time.Duration) *ErrorRetryAfter {
if err == nil {
err = ErrServiceUnavailable
}
if retryAfter < 0 {
retryAfter = 0
}
return &ErrorRetryAfter{
RetryAfter: retryAfter,
Err: err,
}
}

func (e *ErrorRetryAfter) Error() string {
var text string
if e.Err != nil {
text = e.Err.Error()
}
if e.RetryAfter != 0 {
text += fmt.Sprintf(", retry after %s", e.Humanized())
}
return text
}

func (e *ErrorRetryAfter) Unwrap() error {
return e.Err
}

func (e *ErrorRetryAfter) Is(err error) bool {
switch err.(type) {
case *ErrorRetryAfter:
return true
default:
return false
}
}

func (e *ErrorRetryAfter) RoundSeconds() time.Duration {
return e.RetryAfter.Round(time.Second)
}

func (e *ErrorRetryAfter) Humanized() string {
return e.RoundSeconds().String()
}

// HTTPHeaderValue returns the Retry-After header value as a string, representing the number
// of seconds to wait before making a new request, rounded to the nearest second.
// This function follows the Retry-After header definition as specified in RFC 9110.
func (e *ErrorRetryAfter) HTTPHeaderValue() string {
return strconv.Itoa(int(e.RoundSeconds().Seconds()))
}

// Custom type for collecting error details to be handled by `webError`. When an error
// of this type is returned to the gateway handler, the StatusCode will be used for
// the response status.
type ErrorResponse struct {
StatusCode int
Err error
}

func NewErrorResponseForCode(statusCode int) *ErrorResponse {
return NewErrorResponse(errors.New(http.StatusText(statusCode)), statusCode)
}

func NewErrorResponse(err error, statusCode int) *ErrorResponse {
return &ErrorResponse{
Err: err,
StatusCode: statusCode,
}
}

func (e *ErrorResponse) Is(err error) bool {
switch err.(type) {
case *ErrorResponse:
return true
default:
return false
}
}

func (e *ErrorResponse) Error() string {
var text string
if e.Err != nil {
text = e.Err.Error()
}
return text
}

func (e *ErrorResponse) Unwrap() error {
return e.Err
}

func webError(w http.ResponseWriter, err error, defaultCode int) {
code := defaultCode

// Pass Retry-After hint to the client
var era *ErrorRetryAfter
if errors.As(err, &era) {
if era.RetryAfter > 0 {
w.Header().Set("Retry-After", era.HTTPHeaderValue())
// Adjust defaultCode if needed
if code != http.StatusTooManyRequests && code != http.StatusServiceUnavailable {
code = http.StatusTooManyRequests
}
}
err = era.Unwrap()
}

// Handle status code
switch {
case isErrNotFound(err):
code = http.StatusNotFound
case errors.Is(err, context.DeadlineExceeded):
code = http.StatusGatewayTimeout
}

// Handle explicit code in ErrorResponse
var gwErr *ErrorResponse
if errors.As(err, &gwErr) {
code = gwErr.StatusCode
}

http.Error(w, err.Error(), code)
}

func isErrNotFound(err error) bool {
if ipld.IsNotFound(err) {
return true
}

// Checks if err is a resolver.ErrNoLink. resolver.ErrNoLink does not implement
// the .Is interface and cannot be directly compared to. Therefore, errors.Is
// always returns false with it.
for {
_, ok := err.(resolver.ErrNoLink)
if ok {
return true
}

err = errors.Unwrap(err)
if err == nil {
return false
}
}
}

func webRequestError(w http.ResponseWriter, err *ErrorResponse) {
webError(w, err.Err, err.StatusCode)
}
65 changes: 65 additions & 0 deletions gateway/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package gateway

import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestErrRetryAfterIs(t *testing.T) {
var err error

err = NewErrorRetryAfter(errors.New("test"), 10*time.Second)
assert.True(t, errors.Is(err, &ErrorRetryAfter{}), "pointer to error must be error")

err = fmt.Errorf("wrapped: %w", err)
assert.True(t, errors.Is(err, &ErrorRetryAfter{}), "wrapped pointer to error must be error")
}

func TestErrRetryAfterAs(t *testing.T) {
var (
err error
errRA *ErrorRetryAfter
)

err = NewErrorRetryAfter(errors.New("test"), 25*time.Second)
assert.True(t, errors.As(err, &errRA), "pointer to error must be error")
assert.EqualValues(t, errRA.RetryAfter, 25*time.Second)

err = fmt.Errorf("wrapped: %w", err)
assert.True(t, errors.As(err, &errRA), "wrapped pointer to error must be error")
assert.EqualValues(t, errRA.RetryAfter, 25*time.Second)
}

func TestWebError(t *testing.T) {
t.Parallel()

t.Run("429 Too Many Requests", func(t *testing.T) {
err := fmt.Errorf("wrapped for testing: %w", NewErrorRetryAfter(ErrTooManyRequests, 0))
w := httptest.NewRecorder()
webError(w, err, http.StatusInternalServerError)
assert.Equal(t, http.StatusTooManyRequests, w.Result().StatusCode)
assert.Zero(t, len(w.Result().Header.Values("Retry-After")))
})

t.Run("429 Too Many Requests with Retry-After header", func(t *testing.T) {
err := NewErrorRetryAfter(ErrTooManyRequests, 25*time.Second)
w := httptest.NewRecorder()
webError(w, err, http.StatusInternalServerError)
assert.Equal(t, http.StatusTooManyRequests, w.Result().StatusCode)
assert.Equal(t, "25", w.Result().Header.Get("Retry-After"))
})

t.Run("503 Service Unavailable with Retry-After header", func(t *testing.T) {
err := NewErrorRetryAfter(ErrServiceUnavailable, 50*time.Second)
w := httptest.NewRecorder()
webError(w, err, http.StatusInternalServerError)
assert.Equal(t, http.StatusServiceUnavailable, w.Result().StatusCode)
assert.Equal(t, "50", w.Result().Header.Get("Retry-After"))
})
}
Loading

0 comments on commit 15f2131

Please sign in to comment.