Skip to content

Commit

Permalink
fix(httpauth): Correctly handle concurrent requests on server (#3111)
Browse files Browse the repository at this point in the history
Co-authored-by: Adin Schmahmann <[email protected]>
  • Loading branch information
MarcoPolo and aschmahmann authored Dec 18, 2024
1 parent b07e3aa commit a0a498e
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 16 deletions.
66 changes: 55 additions & 11 deletions p2p/http/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package httppeeridauth

import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"hash"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -171,14 +169,12 @@ func TestMutualAuth(t *testing.T) {

t.Run("Tokens Invalidated", func(t *testing.T) {
// Clear the auth token on the server side
server.Hmac = func() hash.Hash {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
panic(err)
}
return hmac.New(sha256.New, key)
}()
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
panic(err)
}
server.hmacPool = newHmacPool(key)

req, err := http.NewRequest("POST", ts.URL, nil)
req.GetBody = func() (io.ReadCloser, error) {
Expand Down Expand Up @@ -241,3 +237,51 @@ func (irt *instrumentedRoundTripper) RoundTrip(req *http.Request) (*http.Respons
func (irt *instrumentedRoundTripper) TLSClientConfig() *tls.Config {
return irt.RoundTripper.(*http.Transport).TLSClientConfig
}

func TestConcurrentAuth(t *testing.T) {
serverKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)

auth := ServerPeerIDAuth{
PrivKey: serverKey,
ValidHostnameFn: func(s string) bool {
return s == "example.com"
},
TokenTTL: time.Hour,
NoTLS: true,
Next: func(peer peer.ID, w http.ResponseWriter, r *http.Request) {
reqBody, err := io.ReadAll(r.Body)
require.NoError(t, err)
_, err = w.Write(reqBody)
require.NoError(t, err)
},
}

ts := httptest.NewServer(&auth)
t.Cleanup(ts.Close)

wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)

clientAuth := ClientPeerIDAuth{PrivKey: clientKey}
reqBody := []byte(fmt.Sprintf("echo %d", i))
req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(reqBody))
require.NoError(t, err)
req.Host = "example.com"

client := ts.Client()
_, resp, err := clientAuth.AuthenticatedDo(client, req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, reqBody, respBody)
}()
}
wg.Wait()
}
39 changes: 34 additions & 5 deletions p2p/http/auth/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,30 @@ import (
"github.com/libp2p/go-libp2p/p2p/http/auth/internal/handshake"
)

type hmacPool struct {
p sync.Pool
}

func newHmacPool(key []byte) *hmacPool {
return &hmacPool{
p: sync.Pool{
New: func() any {
return hmac.New(sha256.New, key)
},
},
}
}

func (p *hmacPool) Get() hash.Hash {
h := p.p.Get().(hash.Hash)
h.Reset()
return h
}

func (p *hmacPool) Put(h hash.Hash) {
p.p.Put(h)
}

type ServerPeerIDAuth struct {
PrivKey crypto.PrivKey
TokenTTL time.Duration
Expand All @@ -26,8 +50,9 @@ type ServerPeerIDAuth struct {
// which the Host header returns true.
ValidHostnameFn func(hostname string) bool

Hmac hash.Hash
HmacKey []byte
initHmac sync.Once
hmacPool *hmacPool
}

// ServeHTTP implements the http.Handler interface for PeerIDAuth. It will
Expand All @@ -36,14 +61,15 @@ type ServerPeerIDAuth struct {
// requests.
func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
a.initHmac.Do(func() {
if a.Hmac == nil {
if a.HmacKey == nil {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
panic(err)
}
a.Hmac = hmac.New(sha256.New, key)
a.HmacKey = key
}
a.hmacPool = newHmacPool(a.HmacKey)
})

hostname := r.Host
Expand Down Expand Up @@ -76,11 +102,13 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

hmac := a.hmacPool.Get()
defer a.hmacPool.Put(hmac)
hs := handshake.PeerIDAuthHandshakeServer{
Hostname: hostname,
PrivKey: a.PrivKey,
TokenTTL: a.TokenTTL,
Hmac: a.Hmac,
Hmac: hmac,
}
err := hs.ParseHeaderVal([]byte(r.Header.Get("Authorization")))
if err != nil {
Expand All @@ -95,11 +123,12 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
errors.Is(err, handshake.ErrExpiredChallenge),
errors.Is(err, handshake.ErrExpiredToken):

hmac.Reset()
hs := handshake.PeerIDAuthHandshakeServer{
Hostname: hostname,
PrivKey: a.PrivKey,
TokenTTL: a.TokenTTL,
Hmac: a.Hmac,
Hmac: hmac,
}
hs.Run()
hs.SetHeader(w.Header())
Expand Down

0 comments on commit a0a498e

Please sign in to comment.