Skip to content

Commit

Permalink
feat: add maxConnections and maxRequestBodySize limit setting. (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
gyuguen authored Mar 13, 2023
1 parent 0529d03 commit 77938e9
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 10 deletions.
16 changes: 10 additions & 6 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ type IPFSConfig struct {
}

type APIConfig struct {
ListenAddr string `mapstructure:"listen-addr"`
WriteTimeout int64 `mapstructure:"write-timeout"`
ReadTimeout int64 `mapstructure:"read-timeout"`
ListenAddr string `mapstructure:"listen-addr"`
WriteTimeout int64 `mapstructure:"write-timeout"`
ReadTimeout int64 `mapstructure:"read-timeout"`
MaxConnections int `mapstructure:"max-connections"`
MaxRequestBodySize int64 `mapstructure:"max-request-body-size"`
}

func DefaultConfig() *Config {
Expand Down Expand Up @@ -81,9 +83,11 @@ func DefaultConfig() *Config {
IPFSNodeAddr: "127.0.0.1:5001",
},
API: APIConfig{
ListenAddr: "127.0.0.1:8080",
WriteTimeout: 60,
ReadTimeout: 15,
ListenAddr: "127.0.0.1:8080",
WriteTimeout: 60,
ReadTimeout: 15,
MaxConnections: 50,
MaxRequestBodySize: 4 << (10 * 2), // 4MB
},
}
}
Expand Down
2 changes: 2 additions & 0 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ ipfs-node-addr = "{{ .IPFS.IPFSNodeAddr }}"
listen-addr = "{{ .API.ListenAddr }}"
write-timeout = "{{ .API.WriteTimeout }}"
read-timeout = "{{ .API.ReadTimeout }}"
max-connections = "{{ .API.MaxConnections }}"
max-request-body-size = "{{ .API.MaxRequestBodySize }}"
`

var configTemplate *template.Template
Expand Down
31 changes: 31 additions & 0 deletions server/middleware/limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package middleware

import (
"net/http"
)

type limitMiddleware struct {
maxRequestBodySize int64
}

func NewLimitMiddleware(maxRequestBodySize int64) *limitMiddleware {
return &limitMiddleware{
maxRequestBodySize,
}
}

// Middleware limits the request body size.
// This is done by first constraining to the ContentLength of the request headder,
// and then reading the actual Body to constraint it.
func (mw *limitMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > mw.maxRequestBodySize {
http.Error(w, "request body too large", http.StatusBadRequest)
return
}
r.Body = http.MaxBytesReader(w, r.Body, mw.maxRequestBodySize)
defer r.Body.Close()

next.ServeHTTP(w, r)
})
}
102 changes: 102 additions & 0 deletions server/middleware/limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package middleware_test

import (
"bytes"
"crypto/rand"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/medibloc/panacea-oracle/server/middleware"
"github.com/stretchr/testify/require"
)

func TestBodySizeSmallerThanLimitSetting(t *testing.T) {
testLimitMiddlewareHTTPRequest(
t,
newRequest(newRandomBody(1023)),
1024,
http.StatusOK,
"",
)
}

func TestBodySizeSameLimitSetting(t *testing.T) {
testLimitMiddlewareHTTPRequest(
t,
newRequest(newRandomBody(1024)),
1024,
http.StatusOK,
"",
)
}

func TestBodySizeLargeThanLimitSetting(t *testing.T) {
testLimitMiddlewareHTTPRequest(
t,
newRequest(newRandomBody(1025)),
1024,
http.StatusBadRequest,
"request body too large",
)
}

func TestDifferentBodySizeAndHeaderContentSize(t *testing.T) {
req := newRequest(newRandomBody(1025))
req.ContentLength = 1024

testLimitMiddlewareHTTPRequest(
t,
req,
1024,
http.StatusBadRequest,
"request body too large",
)
}

func newRandomBody(size int) []byte {
body := make([]byte, size)
if _, err := rand.Read(body); err != nil {
panic(err)
}

return body
}

func newRequest(body []byte) *http.Request {
return httptest.NewRequest(
"GET",
"http://test.com",
bytes.NewBuffer(body),
)
}

func testLimitMiddlewareHTTPRequest(
t *testing.T,
req *http.Request,
maxRequestBodySize int64,
statusCode int,
bodyMsg string,
) {
w := httptest.NewRecorder()
mw := middleware.NewLimitMiddleware(maxRequestBodySize)
testHandler := mw.Middleware(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
}
}),
)

testHandler.ServeHTTP(w, req)

resp := w.Result()
require.Equal(t, statusCode, resp.StatusCode)
if bodyMsg != "" {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), bodyMsg)
}
}
90 changes: 90 additions & 0 deletions server/netutil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package server_test

import (
"io"
"net"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
"golang.org/x/net/netutil"
)

func TestNetUtil(t *testing.T) {

lis := &fakeListener{timeWait: 1}

limitLis := netutil.LimitListener(lis, 2)

wg := &sync.WaitGroup{}
start := time.Now()
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()

conn, err := limitLis.Accept()
require.NoError(t, err)
defer conn.Close()
}(i)
}

wg.Wait()
end := time.Now()

// Send 10 requests, process 2 at a time, and take 1 second per request.
// This request test should take 5 to 6 seconds.
require.True(t, start.Add(time.Second*5).Before(end))
require.True(t, start.Add(time.Second*6).After(end))

}

type fakeListener struct {
timeWait int64
}

// Accept waits for and returns the next connection to the listener.
func (l *fakeListener) Accept() (net.Conn, error) {
time.Sleep(time.Second * time.Duration(l.timeWait))

return fakeNetConn{}, nil
}

// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (l *fakeListener) Close() error {
return nil
}

// Addr returns the listener's network address.
func (l *fakeListener) Addr() net.Addr {
return fakeAddr(1)
}

type fakeNetConn struct {
io.Reader
io.Writer
}

func (c fakeNetConn) Close() error { return nil }
func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }

type fakeAddr int

var (
localAddr = fakeAddr(1)
remoteAddr = fakeAddr(2)
)

func (a fakeAddr) Network() string {
return "net"
}

func (a fakeAddr) String() string {
return "str"
}
24 changes: 20 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"net"
"net/http"
"time"

Expand All @@ -11,19 +12,24 @@ import (
"github.com/medibloc/panacea-oracle/server/service/status"
"github.com/medibloc/panacea-oracle/service"
log "github.com/sirupsen/logrus"
"golang.org/x/net/netutil"
)

type Server struct {
*http.Server
maxConnections int
}

func New(svc service.Service) *Server {
router := mux.NewRouter()
conf := svc.Config()

limitMiddleware := middleware.NewLimitMiddleware(conf.API.MaxRequestBodySize)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(svc.QueryClient())

dealRouter := router.PathPrefix("/v0/data-deal").Subrouter()
dealRouter.Use(
limitMiddleware.Middleware,
jwtAuthMiddleware.Middleware,
)

Expand All @@ -35,14 +41,24 @@ func New(svc service.Service) *Server {
return &Server{
&http.Server{
Handler: router,
Addr: svc.Config().API.ListenAddr,
WriteTimeout: time.Duration(svc.Config().API.WriteTimeout) * time.Second,
ReadTimeout: time.Duration(svc.Config().API.ReadTimeout) * time.Second,
Addr: conf.API.ListenAddr,
WriteTimeout: time.Duration(conf.API.WriteTimeout) * time.Second,
ReadTimeout: time.Duration(conf.API.ReadTimeout) * time.Second,
},
conf.API.MaxConnections,
}
}

func (srv *Server) Run() error {
addr := srv.Addr
if addr == "" {
addr = ":http"
}
lis, err := net.Listen("tcp", addr)
if err != nil {
return err
}

log.Infof("HTTP server is started: %s", srv.Addr)
return srv.ListenAndServe()
return srv.Serve(netutil.LimitListener(lis, srv.maxConnections))
}

0 comments on commit 77938e9

Please sign in to comment.