Skip to content

Commit

Permalink
add support for multiple dial URLs
Browse files Browse the repository at this point in the history
Signed-off-by: Nicola Murino <[email protected]>
  • Loading branch information
drakkan committed May 21, 2024
1 parent 62d91b8 commit 1ac0e67
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ USAGE:
sftpgo-plugin-auth serve [command options] [arguments...]

OPTIONS:
--ldap-url value LDAP url, e.g ldap://192.168.1.5:389 or ldaps://192.168.1.5:636 [$SFTPGO_PLUGIN_AUTH_LDAP_URL]
--ldap-url value [ --ldap-url value ] LDAP url, e.g ldap://192.168.1.5:389 or ldaps://192.168.1.5:636. By specifying multiple URLs you will achieve load balancing and high availability [$SFTPGO_PLUGIN_AUTH_LDAP_URL]
--ldap-base-dn value The base DN defines the address of the root object in the LDAP directory, e.g dc=mylab,dc=local [$SFTPGO_PLUGIN_AUTH_LDAP_BASE_DN]
--ldap-bind-dn value The bind DN used to log in at the LDAP server in order to perform searches, e.g cn=Administrator,cn=users,dc=mylab,dc=local. This should be a read-oly user [$SFTPGO_PLUGIN_AUTH_LDAP_USERNAME, $SFTPGO_PLUGIN_AUTH_LDAP_BIND_DN]
--ldap-password value The password for the defined ldap-bind-dn. If empty an anonymous bind will be attempted [$SFTPGO_PLUGIN_AUTH_LDAP_PASSWORD]
Expand Down
6 changes: 3 additions & 3 deletions authenticator/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var (
sdk.WebClientPasswordResetDisabled}
)

func NewAuthenticator(dialURL, baseDN, username, password string, startTLS int, skipTLSVerify bool,
func NewAuthenticator(dialURLs []string, baseDN, username, password string, startTLS int, skipTLSVerify bool,
baseDir string, cacheTime int, searchQuery string, groupAttributes, caCertificates []string,
primaryGroupPrefix, secondaryGroupPrefix, membershipGroupPrefix string, requiresGroup bool,
) (*LDAPAuthenticator, error) {
Expand All @@ -45,7 +45,7 @@ func NewAuthenticator(dialURL, baseDN, username, password string, startTLS int,
InsecureSkipVerify: skipTLSVerify,
}
auth := &LDAPAuthenticator{
DialURL: dialURL,
DialURLs: dialURLs,
BaseDN: baseDN,
Username: username,
Password: password,
Expand All @@ -70,7 +70,7 @@ func NewAuthenticator(dialURL, baseDN, username, password string, startTLS int,
}
startCleanupTicker(10 * time.Minute)
}
logger.AppLogger.Info("authenticator created", "dial URL", auth.DialURL, "base dn", auth.BaseDN,
logger.AppLogger.Info("authenticator created", "dial URLs", auth.DialURLs, "base dn", auth.BaseDN,
"search query", auth.SearchQuery)
return auth, nil
}
165 changes: 161 additions & 4 deletions authenticator/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"net"
"path/filepath"
"strings"
"sync"
"time"

"github.com/go-ldap/ldap/v3"
Expand All @@ -31,7 +33,7 @@ import (
)

type LDAPAuthenticator struct {
DialURL string
DialURLs []string
BaseDN string
Username string
Password string
Expand All @@ -44,10 +46,21 @@ type LDAPAuthenticator struct {
RequireGroups bool
BaseDir string
tlsConfig *tls.Config
monitorTicker *time.Ticker
cleanupDone chan bool
mu sync.RWMutex
activeURLs []string
}

func (a *LDAPAuthenticator) validate() error {
if a.DialURL == "" {
var urls []string
for _, u := range a.DialURLs {
if u != "" && !contains(urls, u) {
urls = append(urls, u)
}
}
a.DialURLs = urls
if len(a.DialURLs) == 0 {
return errors.New("ldap: dial URL is required")
}
if a.BaseDN == "" {
Expand Down Expand Up @@ -78,9 +91,96 @@ func (a *LDAPAuthenticator) validate() error {
return errors.New("group attributes not set, group prefixes are ineffective")
}
}
a.setActiveDialURLs(a.DialURLs)
return nil
}

func (a *LDAPAuthenticator) setActiveDialURLs(urls []string) {
if len(a.DialURLs) == 1 {
return
}
a.startMonitorTicker(2 * time.Minute)

a.mu.Lock()
defer a.mu.Unlock()

a.activeURLs = nil
a.activeURLs = append(a.activeURLs, urls...)
}

func (a *LDAPAuthenticator) addActiveDialURL(val string) {
if len(a.DialURLs) == 1 {
return
}
a.mu.Lock()
defer a.mu.Unlock()

if !contains(a.activeURLs, val) {
a.activeURLs = append(a.activeURLs, val)
logger.AppLogger.Info("ldap connection restored", "dial URL", val,
"number of active dial URLs", len(a.activeURLs))
}
}

func (a *LDAPAuthenticator) removeActiveDialURL(val string, err error) {
if len(a.DialURLs) == 1 {
return
}
a.mu.Lock()
defer a.mu.Unlock()

var urls []string
for _, u := range a.activeURLs {
if u != val {
urls = append(urls, u)
}
}
a.activeURLs = urls
logger.AppLogger.Error("ldap connection error", "dial URL", val, "error", err,
"number of active dial URLs", len(a.activeURLs))
}

func (a *LDAPAuthenticator) getDialURLs() []string {
if len(a.DialURLs) == 1 {
return a.DialURLs
}
a.mu.RLock()
defer a.mu.RUnlock()

if len(a.activeURLs) == 0 {
logger.AppLogger.Warn("no active dial URL, trying all the defined URLs")
return a.DialURLs
}

urls := make([]string, len(a.activeURLs))
copy(urls, a.activeURLs)

rand.Shuffle(len(urls), func(i, j int) {
urls[i], urls[j] = urls[j], urls[i]
})

return urls
}

func (a *LDAPAuthenticator) isDialURLActive(val string) bool {
a.mu.RLock()
defer a.mu.RUnlock()

return contains(a.activeURLs, val)
}

func (a *LDAPAuthenticator) monitorDialURLs() {
for _, u := range a.DialURLs {
if !a.isDialURLActive(u) {
conn, err := a.getLDAPConnection(u)
if err == nil {
conn.Close()
a.addActiveDialURL(u)
}
}
}
}

func (a *LDAPAuthenticator) CheckUserAndPass(username, password, _, _ string, userAsJSON []byte) ([]byte, error) {
if password == "" {
return nil, errInvalidCredentials
Expand Down Expand Up @@ -322,12 +422,27 @@ func (a *LDAPAuthenticator) isUserToUpdate(u *sdk.User, groups []sdk.GroupMappin
return false
}

func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
func (a *LDAPAuthenticator) connect() (conn *ldap.Conn, err error) {
for _, url := range a.getDialURLs() {
conn, err = a.getLDAPConnection(url)
if err == nil {
a.addActiveDialURL(url)
} else {
a.removeActiveDialURL(url, err)
}
if !a.isRetryableError(err) {
return
}
}
return
}

func (a *LDAPAuthenticator) getLDAPConnection(dialURL string) (*ldap.Conn, error) {
opts := []ldap.DialOpt{
ldap.DialWithDialer(&net.Dialer{Timeout: 15 * time.Second}),
ldap.DialWithTLSConfig(a.tlsConfig),
}
l, err := ldap.DialURL(a.DialURL, opts...)
l, err := ldap.DialURL(dialURL, opts...)
if err != nil {
return nil, err
}
Expand All @@ -339,3 +454,45 @@ func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
}
return l, err
}

func (*LDAPAuthenticator) isRetryableError(err error) bool {
if err == nil {
return false
}
var ldapErr *ldap.Error
if errors.As(err, &ldapErr) {
return ldapErr.ResultCode == ldap.ErrorNetwork
}
return false
}

func (a *LDAPAuthenticator) stopMonitorTicker() {
if a.monitorTicker != nil {
a.monitorTicker.Stop()
a.cleanupDone <- true
a.monitorTicker = nil
}
}

func (a *LDAPAuthenticator) startMonitorTicker(interval time.Duration) {
a.stopMonitorTicker()
a.monitorTicker = time.NewTicker(interval)
a.cleanupDone = make(chan bool)

go func() {
logger.AppLogger.Info("start monitor task for dial URLs", "dial URLs", len(a.DialURLs))
for {
select {
case <-a.cleanupDone:
logger.AppLogger.Info("monitor task for dial URLs ended")
return
case <-a.monitorTicker.C:
a.monitorDialURLs()
}
}
}()
}

func (a *LDAPAuthenticator) Cleanup() {
a.stopMonitorTicker()
}
63 changes: 58 additions & 5 deletions authenticator/ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ import (
)

const (
ldapURL = "ldap://localhost:3893"
ldapsURL = "ldaps://localhost:3894"
baseDN = "dc=glauth,dc=com"
username = "cn=serviceuser,dc=glauth,dc=com"
password = "mysecret"
Expand Down Expand Up @@ -72,6 +70,12 @@ EsdgvPZR2e5IkA==
-----END CERTIFICATE-----`
)

var (
ldapURL = []string{"ldap://localhost:3893"}
ldapsURL = []string{"ldaps://localhost:3894"}
multipleLDAPURLs = []string{"ldap://localhost:3893", "ldap://localhost:3895"}
)

func TestLDAPAuthenticator(t *testing.T) {
baseDir := filepath.Clean(os.TempDir())
auth, err := NewAuthenticator(ldapURL, baseDN, username, password, 0, false, baseDir, 2, searchQuery,
Expand Down Expand Up @@ -189,7 +193,7 @@ func TestLDAPS(t *testing.T) {
}

func TestLDAPConnectionErrors(t *testing.T) {
auth, err := NewAuthenticator("ldap://localhost:3892", baseDN, username, password, 0, true, "", 0, searchQuery,
auth, err := NewAuthenticator([]string{"ldap://localhost:3892"}, baseDN, username, password, 0, true, "", 0, searchQuery,
[]string{groupAttribute}, nil, primaryGroupPrefix, secondaryGroupPrefix, membershipGroupPrefix, false)
require.NoError(t, err)
_, err = auth.CheckUserAndPass(user1, password, "", "", nil)
Expand All @@ -213,11 +217,14 @@ func TestStartTLS(t *testing.T) {
}

func TestValidation(t *testing.T) {
_, err := NewAuthenticator("", "", "", "", 0, false, "", 0, "", nil, nil, "", "", "", false)
_, err := NewAuthenticator(nil, "", "", "", 0, false, "", 0, "", nil, nil, "", "", "", false)
require.Error(t, err)
assert.Contains(t, err.Error(), "dial URL is required")
_, err = NewAuthenticator([]string{"", ""}, "", "", "", 0, false, "", 0, "", nil, nil, "", "", "", false)
require.Error(t, err)
assert.Contains(t, err.Error(), "dial URL is required")
a := LDAPAuthenticator{
DialURL: ldapURL,
DialURLs: ldapURL,
}
err = a.validate()
require.Error(t, err)
Expand Down Expand Up @@ -258,6 +265,13 @@ func TestValidation(t *testing.T) {
a.PrimaryGroupPrefix = "sftpgo_primary"
err = a.validate()
require.NoError(t, err)
a.DialURLs = []string{"ldap://1.2.3.4:389", "ldap://1.2.3.4:389", "ldap://1.2.3.5:389"}
err = a.validate()
require.NoError(t, err)
assert.Len(t, a.DialURLs, 2)
assert.Contains(t, a.DialURLs, "ldap://1.2.3.4:389")
assert.Contains(t, a.DialURLs, "ldap://1.2.3.5:389")
a.Cleanup()
}

func TestUnsupportedAuthMethods(t *testing.T) {
Expand Down Expand Up @@ -349,3 +363,42 @@ func TestLoadCACerts(t *testing.T) {
err = os.Remove(caCrtPath)
require.NoError(t, err)
}

func TestLDAPMonitor(t *testing.T) {
auth, err := NewAuthenticator(multipleLDAPURLs, baseDN, username, password, 0, false, "", 2, searchQuery,
[]string{groupAttribute}, nil, primaryGroupPrefix, secondaryGroupPrefix, membershipGroupPrefix, true)
require.NoError(t, err)
defer auth.Cleanup()

assert.Len(t, auth.getDialURLs(), 2)
auth.removeActiveDialURL(multipleLDAPURLs[0], nil)
auth.removeActiveDialURL(multipleLDAPURLs[1], nil)

auth.startMonitorTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool {
return len(auth.getDialURLs()) == 1
}, time.Second, 250*time.Millisecond)

auth.removeActiveDialURL(multipleLDAPURLs[0], nil)
auth.removeActiveDialURL(multipleLDAPURLs[1], nil)
// no active URL, all defined urls will be returned
assert.Len(t, auth.getDialURLs(), 2)
}

func TestRetryableErrors(t *testing.T) {
a := LDAPAuthenticator{}
require.False(t, a.isRetryableError(nil))

err := &ldap.Error{
Err: errNotImplemented,
ResultCode: ldap.ErrorNetwork,
}
require.True(t, a.isRetryableError(err))

err = &ldap.Error{
Err: errNotImplemented,
ResultCode: ldap.ErrorUnexpectedMessage,
}
require.False(t, a.isRetryableError(err))
require.False(t, a.isRetryableError(fs.ErrPermission))
}
Loading

0 comments on commit 1ac0e67

Please sign in to comment.