Skip to content

Commit

Permalink
Use atomic pointer for map access synchronization (#12)
Browse files Browse the repository at this point in the history
* use atomic pointer for map access synchronization

Signed-off-by: Vladislav Yarmak <[email protected]>

* restore coverage

Signed-off-by: Vladislav Yarmak <[email protected]>

---------

Signed-off-by: Vladislav Yarmak <[email protected]>
  • Loading branch information
Snawoot authored Oct 29, 2024
1 parent 27f6cbb commit 383357b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 22 deletions.
18 changes: 5 additions & 13 deletions htgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"io"
"os"
"strings"
"sync"
"sync/atomic"
)

// Data structure for users and theirs groups (map).
Expand All @@ -26,8 +26,7 @@ type userGroupMap map[string][]string
// A HTGroup encompasses an Apache-style group file.
type HTGroup struct {
filePath string
mutex sync.RWMutex
userGroups userGroupMap
userGroups atomic.Pointer[userGroupMap]
}

// NewGroups creates a HTGroup from an Apache-style group file.
Expand Down Expand Up @@ -56,10 +55,7 @@ func NewGroupsFromReader(r io.Reader, bad BadLineHandler) (*HTGroup, error) {

// ReloadGroups rereads the group file.
func (htGroup *HTGroup) ReloadGroups(bad BadLineHandler) error {
htGroup.mutex.Lock()
filename := htGroup.filePath
htGroup.mutex.Unlock()
file, err := os.Open(filename)
file, err := os.Open(htGroup.filePath)
if err != nil {
return err
}
Expand All @@ -83,9 +79,7 @@ func (htGroup *HTGroup) ReloadGroupsFromReader(r io.Reader, bad BadLineHandler)
return fmt.Errorf("Error scanning group file: %s", scannerErr.Error())
}

htGroup.mutex.Lock()
htGroup.userGroups = userGroups
htGroup.mutex.Unlock()
htGroup.userGroups.Store(&userGroups)

return nil
}
Expand Down Expand Up @@ -123,9 +117,7 @@ func (htGroup *HTGroup) IsUserInGroup(user string, group string) bool {
// GetUserGroups reads all groups of a user.
// Returns all groups as a string array or an empty array.
func (htGroup *HTGroup) GetUserGroups(user string) []string {
htGroup.mutex.RLock()
groups := htGroup.userGroups[user]
htGroup.mutex.RUnlock()
groups := (*htGroup.userGroups.Load())[user]

if groups == nil {
return []string{}
Expand Down
17 changes: 17 additions & 0 deletions htgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package htpasswd

import (
"os"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -66,4 +67,20 @@ func TestGroups(t *testing.T) {
assert.Len(t, htGroup.GetUserGroups("user2"), 2)
assert.Len(t, htGroup.GetUserGroups("user3"), 1)
assert.Len(t, htGroup.GetUserGroups("unknownuser"), 0)

// Test load from reader as well
r := strings.NewReader(contents2)
htGroup, err = NewGroupsFromReader(r, nil)
assert.NoError(t, err)
assert.True(t, htGroup.IsUserInGroup("user1", "users"))
assert.True(t, htGroup.IsUserInGroup("user1", "admins"))
assert.True(t, htGroup.IsUserInGroup("user2", "users"))
assert.True(t, htGroup.IsUserInGroup("user2", "admins"))
assert.False(t, htGroup.IsUserInGroup("unknownuser", "users"))
assert.False(t, htGroup.IsUserInGroup("user1", "unknowngroup"))
assert.False(t, htGroup.IsUserInGroup("unknownuser", "unknowngroup"))
assert.Len(t, htGroup.GetUserGroups("user1"), 2)
assert.Len(t, htGroup.GetUserGroups("user2"), 2)
assert.Len(t, htGroup.GetUserGroups("user3"), 1)
assert.Len(t, htGroup.GetUserGroups("unknownuser"), 0)
}
13 changes: 4 additions & 9 deletions htpasswd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"io"
"os"
"strings"
"sync"
"sync/atomic"
)

// An EncodedPasswd is created from the encoded password in a password file by a PasswdParser.
Expand Down Expand Up @@ -53,8 +53,7 @@ type BadLineHandler func(err error)
// An File encompasses an Apache-style htpasswd file for HTTP Basic authentication
type File struct {
filePath string
mutex sync.RWMutex
passwds passwdTable
passwds atomic.Pointer[passwdTable]
parsers []PasswdParser
}

Expand Down Expand Up @@ -104,9 +103,7 @@ func NewFromReader(r io.Reader, parsers []PasswdParser, bad BadLineHandler) (*Fi
// Match checks the username and password combination to see if it represents
// a valid account from the htpassword file.
func (bf *File) Match(username, password string) bool {
bf.mutex.RLock()
matcher, ok := bf.passwds[username]
bf.mutex.RUnlock()
matcher, ok := (*bf.passwds.Load())[username]

if ok && matcher.MatchesPassword(password) {
// we are good
Expand Down Expand Up @@ -154,9 +151,7 @@ func (bf *File) ReloadFromReader(r io.Reader, bad BadLineHandler) error {
}

// .. finally, safely swap in the new map
bf.mutex.Lock()
bf.passwds = newPasswdMap
bf.mutex.Unlock()
bf.passwds.Store(&newPasswdMap)

return nil
}
Expand Down

0 comments on commit 383357b

Please sign in to comment.