From be611d216cd47a021fd8f9096985f06faee936c0 Mon Sep 17 00:00:00 2001 From: yury Date: Tue, 7 Aug 2018 17:07:21 +0300 Subject: [PATCH] Implement authentication This commit adds authenticator interface with 3 implementations: - no password (instant authentication) - clear text password - md5 hashed password One of these authenticators can be selected in sess.go --- auth.go | 164 +++++++++++++++++++++++++++++++++++++++++++++-- auth_test.go | 177 ++++++++++++++++++++++++++++++++++++++++++++++++++- sess.go | 8 ++- 3 files changed, 342 insertions(+), 7 deletions(-) diff --git a/auth.go b/auth.go index 514a384..3fb067a 100644 --- a/auth.go +++ b/auth.go @@ -1,11 +1,167 @@ package pgsrv +import ( + "bytes" + "crypto/md5" + "crypto/rand" + "fmt" +) + +// authenticator interface defines objects able to perform user authentication +// that happens at the very beginning of every session. type authenticator interface { - authenticate() msg + authenticate() (msg, error) +} + +// authenticationNoPassword responds with auth OK immediately. +type authenticationNoPassword struct{} + +func (*authenticationNoPassword) authenticate() (msg, error) { + return msg{'R', 0, 0, 0, 8, 0, 0, 0, 0}, nil +} + +// messageReadWriter describes objects that handle client-server communication. +// Objects implementing this interface are used to send password requests to users, +// and receive their responses. +type messageReadWriter interface { + Write(m msg) error + Read() (msg, error) +} + +// passwordProvider describes objects that are able to provide a password given a user name. +type passwordProvider interface { + getPassword(user string) ([]byte, error) +} + +// constantPasswordProvider is a password provider that always returns the same password, +// which it is given during the initialization. +type constantPasswordProvider struct { + password []byte +} + +func (cpp *constantPasswordProvider) getPassword(user string) ([]byte, error) { + return cpp.password, nil +} + +// authenticationClearText is an authenticator that requests and accepts a clear text password +// from the client. It is not recommended to use it for security reasons. +// +// It requires a messageReadWriter implementation to communicate with the client, +// passwordProvider implementation to verify that the provided password is correct, +// and a map of arguments that were sent at the beginning of the session (user, database, etc) +type authenticationClearText struct { + rw messageReadWriter + args map[string]interface{} + pp passwordProvider +} + +func (a *authenticationClearText) authenticate() (msg, error) { + // AuthenticationClearText + passwordRequest := msg{ + 'R', + 0, 0, 0, 8, + 0, 0, 0, 3, + } + + err := a.rw.Write(passwordRequest) + if err != nil { + return msg{}, err + } + + m, err := a.rw.Read() + if err != nil { + return msg{}, err + } + + if m.Type() != 'p' { + return msg{}, + fmt.Errorf("expected password response, got message type %c", m.Type()) + } + + user := a.args["user"].(string) + expectedPassword, err := a.pp.getPassword(user) + actualPassword := extractPassword(m) + + if !bytes.Equal(expectedPassword, actualPassword) { + return msg{}, + fmt.Errorf("Password does not match for user \"%s\"", user) + } + + return authOKMsg(), nil +} + +// authenticationMD5 is an authenticator that requests and accepts an MD5 hashed password +// from the client. +// +// It requires a messageReadWriter implementation to communicate with the client, +// passwordProvider implementation to verify that the provided password is correct, +// and a map of arguments that were sent at the beginning of the session (user, database, etc) +type authenticationMD5 struct { + rw messageReadWriter + args map[string]interface{} + pp passwordProvider +} + +func (a *authenticationMD5) authenticate() (msg, error) { + // AuthenticationMD5Password + passwordRequest := msg{ + 'R', + 0, 0, 0, 12, + 0, 0, 0, 5, + } + salt := getRandomSalt() + passwordRequest = append(passwordRequest, salt...) + + err := a.rw.Write(passwordRequest) + if err != nil { + return msg{}, err + } + + m, err := a.rw.Read() + if err != nil { + return msg{}, err + } + + if m.Type() != 'p' { + return msg{}, + fmt.Errorf("expected password response, got message type %c", m.Type()) + } + + user := a.args["user"].(string) + expectedPassword, err := a.pp.getPassword(user) + expectedHash := hashUserPassword(user, expectedPassword, salt) + + actualHash := extractPassword(m) + + if !bytes.Equal(expectedHash, actualHash) { + return msg{}, + fmt.Errorf("Password does not match for user \"%s\"", user) + } + + return authOKMsg(), nil +} + +// getRandomSalt returns a cryptographically secure random slice of 4 bytes. +func getRandomSalt() []byte { + salt := make([]byte, 4) + rand.Read(salt) + return salt } -type noPassword struct{} +// extractPassword extracts the password from a provided 'p' message. +// It assumes that the message is valid. +func extractPassword(m msg) []byte { + // password starts after the size (4 bytes) and lasts until null-terminator + return m[5 : len(m)-1] +} -func (*noPassword) authenticate() msg { - return msg{'R', 0, 0, 0, 8, 0, 0, 0, 0} +// hashUserPassword hashes the provided username and password with the provided salt +// using the same MD5 hashing technique as postgresql MD5 authentication +func hashUserPassword(user string, password, salt []byte) []byte { + // concat('md5', md5(concat(md5(concat(password, username)), random-salt))) + pu := append(password, []byte(user)...) + puHash := fmt.Sprintf("%x", md5.Sum(pu)) + puHashSalted := append([]byte(puHash), salt...) + finalHash := fmt.Sprintf("md5%x", md5.Sum(puHashSalted)) + return []byte(finalHash) } diff --git a/auth_test.go b/auth_test.go index c2c1a8a..71c4ce5 100644 --- a/auth_test.go +++ b/auth_test.go @@ -6,8 +6,181 @@ import ( ) func TestNoPassword_authenticate(t *testing.T) { - np := &noPassword{} - actualResult := np.authenticate() + np := &authenticationNoPassword{} + actualResult, err := np.authenticate() expectedResult := msg{'R', 0, 0, 0, 8, 0, 0, 0, 0} + + require.NoError(t, err) require.Equal(t, actualResult, expectedResult) } + +func TestAuthenticationClearText_authenticate(t *testing.T) { + passwordMessage := msg{ + 'p', + 0, 0, 0, 8, + 109, 101, 104, 0, // 'meh' + } + rw := &mockMessageReadWriter{output: []msg{passwordMessage}} + args := map[string]interface{}{ + "user": "this-is-user", + } + pp := &constantPasswordProvider{password: []byte("meh")} + + a := &authenticationClearText{rw, args, pp} + + t.Run("valid password", func(t *testing.T) { + expectedResult := authOKMsg() + actualResult, err := a.authenticate() + + require.NoError(t, err) + require.Equal(t, expectedResult, actualResult) + }) + + t.Run("invalid password", func(t *testing.T) { + pp.password = []byte("shtoot") + _, err := a.authenticate() + + require.EqualError(t, err, + "Password does not match for user \"this-is-user\"") + }) + + t.Run("invalid message type", func(t *testing.T) { + a.rw = &mockMessageReadWriter{output: []msg{ + {'q', 0, 0, 0, 5, 1}, + }} + _, err := a.authenticate() + + require.EqualError(t, err, + "expected password response, got message type q") + }) +} + +func TestAuthenticationMD5_authenticate(t *testing.T) { + rw := &mockMD5MessageReadWriter{ + user: "postgres", + pass: []byte("test"), + salt: []byte{}, + } + args := map[string]interface{}{ + "user": "postgres", + } + pp := &constantPasswordProvider{password: []byte("test")} + + a := &authenticationMD5{rw, args, pp} + + t.Run("valid password", func(t *testing.T) { + expectedResult := authOKMsg() + actualResult, err := a.authenticate() + + require.NoError(t, err) + require.Equal(t, expectedResult, actualResult) + }) + + t.Run("invalid password", func(t *testing.T) { + pp.password = []byte("shtoot") + _, err := a.authenticate() + + require.EqualError(t, err, + "Password does not match for user \"postgres\"") + }) + + t.Run("invalid message type", func(t *testing.T) { + a.rw = &mockMessageReadWriter{output: []msg{ + {'q', 0, 0, 0, 5, 1}, + }} + _, err := a.authenticate() + + require.EqualError(t, err, + "expected password response, got message type q") + }) +} + +func TestHashUserPassword(t *testing.T) { + user := "postgres" + pass := []byte("test") + salt := []byte{196, 53, 49, 235} + + // actual hash received from psql using the above variables + expectedHash := []byte{ + 109, 100, 53, 97, 97, 51, 102, 56, 98, 56, + 55, 97, 57, 51, 52, 97, 52, 53, 48, 52, + 52, 101, 49, 102, 98, 50, 100, 57, 48, 55, + 48, 99, 98, 56, 48, + } + + actualHash := hashUserPassword(user, pass, salt) + require.Equal(t, expectedHash, actualHash) +} + +func TestGetRandomSalt(t *testing.T) { + var lastSalt []byte + for i := 0; i < 100; i++ { + salt := getRandomSalt() + require.Equal(t, len(salt), 4) + require.NotEqual(t, lastSalt, salt) + lastSalt = salt + } +} + +func TestExtractPassword(t *testing.T) { + t.Run("regular password", func(t *testing.T) { + passwordMessage := msg{ + 'p', + 0, 0, 0, 9, + 42, 42, 42, 42, + 0, + } + + expectedResult := []byte{42, 42, 42, 42} + actualResult := extractPassword(passwordMessage) + require.Equal(t, expectedResult, actualResult) + }) + + t.Run("empty password", func(t *testing.T) { + passwordMessage := msg{ + 'p', + 0, 0, 0, 5, + 0, + } + + expectedResult := []byte{} + actualResult := extractPassword(passwordMessage) + require.Equal(t, expectedResult, actualResult) + }) +} + +// mockMessageReadWriter implements messageReadWriter and outputs the provided output +// message by message, looped. +type mockMessageReadWriter struct { + output []msg + currentOutput int +} + +func (rw *mockMessageReadWriter) Read() (msg, error) { + return rw.output[rw.currentOutput%len(rw.output)], nil +} + +func (rw *mockMessageReadWriter) Write(m msg) error { return nil } + +// mockMD5MessageReadWriter implements messageReadWriter and outputs password +// hashed with the salt received in Write() method +type mockMD5MessageReadWriter struct { + user string + pass []byte + salt []byte +} + +func (rw *mockMD5MessageReadWriter) Read() (msg, error) { + message := msg{ + 'p', + 0, 0, 0, 25, + } + message = append(message, hashUserPassword(rw.user, rw.pass, rw.salt)...) + message = append(message, 0) + return message, nil +} + +func (rw *mockMD5MessageReadWriter) Write(m msg) error { + rw.salt = m[9:len(m)] + return nil +} diff --git a/sess.go b/sess.go index 5b9cfdc..64e45c7 100644 --- a/sess.go +++ b/sess.go @@ -95,7 +95,13 @@ func (s *session) Serve() error { s.initialized = true // handle authentication. - err = s.Write(authOKMsg()) + a := &authenticationNoPassword{} + authResponse, err := a.authenticate() + if err != nil { + return s.Write(errMsg(WithSeverity(err, "FATAL"))) + } + + err = s.Write(authResponse) if err != nil { return err }