From 383357bdd60efcacc0ae72e33f8bf7afa9637760 Mon Sep 17 00:00:00 2001 From: Snawoot Date: Tue, 29 Oct 2024 20:21:52 +0200 Subject: [PATCH] Use atomic pointer for map access synchronization (#12) * use atomic pointer for map access synchronization Signed-off-by: Vladislav Yarmak * restore coverage Signed-off-by: Vladislav Yarmak --------- Signed-off-by: Vladislav Yarmak --- htgroup.go | 18 +++++------------- htgroup_test.go | 17 +++++++++++++++++ htpasswd.go | 13 ++++--------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/htgroup.go b/htgroup.go index b217950..8ea7555 100644 --- a/htgroup.go +++ b/htgroup.go @@ -16,7 +16,7 @@ import ( "io" "os" "strings" - "sync" + "sync/atomic" ) // Data structure for users and theirs groups (map). @@ -26,8 +26,7 @@ type userGroupMap map[string][]string // A HTGroup encompasses an Apache-style group file. type HTGroup struct { filePath string - mutex sync.RWMutex - userGroups userGroupMap + userGroups atomic.Pointer[userGroupMap] } // NewGroups creates a HTGroup from an Apache-style group file. @@ -56,10 +55,7 @@ func NewGroupsFromReader(r io.Reader, bad BadLineHandler) (*HTGroup, error) { // ReloadGroups rereads the group file. func (htGroup *HTGroup) ReloadGroups(bad BadLineHandler) error { - htGroup.mutex.Lock() - filename := htGroup.filePath - htGroup.mutex.Unlock() - file, err := os.Open(filename) + file, err := os.Open(htGroup.filePath) if err != nil { return err } @@ -83,9 +79,7 @@ func (htGroup *HTGroup) ReloadGroupsFromReader(r io.Reader, bad BadLineHandler) return fmt.Errorf("Error scanning group file: %s", scannerErr.Error()) } - htGroup.mutex.Lock() - htGroup.userGroups = userGroups - htGroup.mutex.Unlock() + htGroup.userGroups.Store(&userGroups) return nil } @@ -123,9 +117,7 @@ func (htGroup *HTGroup) IsUserInGroup(user string, group string) bool { // GetUserGroups reads all groups of a user. // Returns all groups as a string array or an empty array. func (htGroup *HTGroup) GetUserGroups(user string) []string { - htGroup.mutex.RLock() - groups := htGroup.userGroups[user] - htGroup.mutex.RUnlock() + groups := (*htGroup.userGroups.Load())[user] if groups == nil { return []string{} diff --git a/htgroup_test.go b/htgroup_test.go index 2bfeefa..438f673 100644 --- a/htgroup_test.go +++ b/htgroup_test.go @@ -2,6 +2,7 @@ package htpasswd import ( "os" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -66,4 +67,20 @@ func TestGroups(t *testing.T) { assert.Len(t, htGroup.GetUserGroups("user2"), 2) assert.Len(t, htGroup.GetUserGroups("user3"), 1) assert.Len(t, htGroup.GetUserGroups("unknownuser"), 0) + + // Test load from reader as well + r := strings.NewReader(contents2) + htGroup, err = NewGroupsFromReader(r, nil) + assert.NoError(t, err) + assert.True(t, htGroup.IsUserInGroup("user1", "users")) + assert.True(t, htGroup.IsUserInGroup("user1", "admins")) + assert.True(t, htGroup.IsUserInGroup("user2", "users")) + assert.True(t, htGroup.IsUserInGroup("user2", "admins")) + assert.False(t, htGroup.IsUserInGroup("unknownuser", "users")) + assert.False(t, htGroup.IsUserInGroup("user1", "unknowngroup")) + assert.False(t, htGroup.IsUserInGroup("unknownuser", "unknowngroup")) + assert.Len(t, htGroup.GetUserGroups("user1"), 2) + assert.Len(t, htGroup.GetUserGroups("user2"), 2) + assert.Len(t, htGroup.GetUserGroups("user3"), 1) + assert.Len(t, htGroup.GetUserGroups("unknownuser"), 0) } diff --git a/htpasswd.go b/htpasswd.go index 44665dd..838935c 100644 --- a/htpasswd.go +++ b/htpasswd.go @@ -19,7 +19,7 @@ import ( "io" "os" "strings" - "sync" + "sync/atomic" ) // An EncodedPasswd is created from the encoded password in a password file by a PasswdParser. @@ -53,8 +53,7 @@ type BadLineHandler func(err error) // An File encompasses an Apache-style htpasswd file for HTTP Basic authentication type File struct { filePath string - mutex sync.RWMutex - passwds passwdTable + passwds atomic.Pointer[passwdTable] parsers []PasswdParser } @@ -104,9 +103,7 @@ func NewFromReader(r io.Reader, parsers []PasswdParser, bad BadLineHandler) (*Fi // Match checks the username and password combination to see if it represents // a valid account from the htpassword file. func (bf *File) Match(username, password string) bool { - bf.mutex.RLock() - matcher, ok := bf.passwds[username] - bf.mutex.RUnlock() + matcher, ok := (*bf.passwds.Load())[username] if ok && matcher.MatchesPassword(password) { // we are good @@ -154,9 +151,7 @@ func (bf *File) ReloadFromReader(r io.Reader, bad BadLineHandler) error { } // .. finally, safely swap in the new map - bf.mutex.Lock() - bf.passwds = newPasswdMap - bf.mutex.Unlock() + bf.passwds.Store(&newPasswdMap) return nil }