diff --git a/README.md b/README.md index 373edc9..5f500ae 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ USAGE: sftpgo-plugin-auth serve [command options] [arguments...] OPTIONS: - --ldap-url value LDAP url, e.g ldap://192.168.1.5:389 or ldaps://192.168.1.5:636 [$SFTPGO_PLUGIN_AUTH_LDAP_URL] + --ldap-url value [ --ldap-url value ] LDAP url, e.g ldap://192.168.1.5:389 or ldaps://192.168.1.5:636. By specifying multiple URLs you will achieve load balancing and high availability [$SFTPGO_PLUGIN_AUTH_LDAP_URL] --ldap-base-dn value The base DN defines the address of the root object in the LDAP directory, e.g dc=mylab,dc=local [$SFTPGO_PLUGIN_AUTH_LDAP_BASE_DN] --ldap-bind-dn value The bind DN used to log in at the LDAP server in order to perform searches, e.g cn=Administrator,cn=users,dc=mylab,dc=local. This should be a read-oly user [$SFTPGO_PLUGIN_AUTH_LDAP_USERNAME, $SFTPGO_PLUGIN_AUTH_LDAP_BIND_DN] --ldap-password value The password for the defined ldap-bind-dn. If empty an anonymous bind will be attempted [$SFTPGO_PLUGIN_AUTH_LDAP_PASSWORD] diff --git a/authenticator/authenticator.go b/authenticator/authenticator.go index 3a7b0b1..c38ab15 100644 --- a/authenticator/authenticator.go +++ b/authenticator/authenticator.go @@ -32,7 +32,7 @@ var ( sdk.WebClientPasswordResetDisabled} ) -func NewAuthenticator(dialURL, baseDN, username, password string, startTLS int, skipTLSVerify bool, +func NewAuthenticator(dialURLs []string, baseDN, username, password string, startTLS int, skipTLSVerify bool, baseDir string, cacheTime int, searchQuery string, groupAttributes, caCertificates []string, primaryGroupPrefix, secondaryGroupPrefix, membershipGroupPrefix string, requiresGroup bool, ) (*LDAPAuthenticator, error) { @@ -45,7 +45,7 @@ func NewAuthenticator(dialURL, baseDN, username, password string, startTLS int, InsecureSkipVerify: skipTLSVerify, } auth := &LDAPAuthenticator{ - DialURL: dialURL, + DialURLs: dialURLs, BaseDN: baseDN, Username: username, Password: password, @@ -70,7 +70,7 @@ func NewAuthenticator(dialURL, baseDN, username, password string, startTLS int, } startCleanupTicker(10 * time.Minute) } - logger.AppLogger.Info("authenticator created", "dial URL", auth.DialURL, "base dn", auth.BaseDN, + logger.AppLogger.Info("authenticator created", "dial URLs", auth.DialURLs, "base dn", auth.BaseDN, "search query", auth.SearchQuery) return auth, nil } diff --git a/authenticator/ldap.go b/authenticator/ldap.go index 4302a29..39c3702 100644 --- a/authenticator/ldap.go +++ b/authenticator/ldap.go @@ -19,9 +19,11 @@ import ( "encoding/json" "errors" "fmt" + "math/rand" "net" "path/filepath" "strings" + "sync" "time" "github.com/go-ldap/ldap/v3" @@ -31,7 +33,7 @@ import ( ) type LDAPAuthenticator struct { - DialURL string + DialURLs []string BaseDN string Username string Password string @@ -44,10 +46,21 @@ type LDAPAuthenticator struct { RequireGroups bool BaseDir string tlsConfig *tls.Config + monitorTicker *time.Ticker + cleanupDone chan bool + mu sync.RWMutex + activeURLs []string } func (a *LDAPAuthenticator) validate() error { - if a.DialURL == "" { + var urls []string + for _, u := range a.DialURLs { + if u != "" && !contains(urls, u) { + urls = append(urls, u) + } + } + a.DialURLs = urls + if len(a.DialURLs) == 0 { return errors.New("ldap: dial URL is required") } if a.BaseDN == "" { @@ -78,9 +91,96 @@ func (a *LDAPAuthenticator) validate() error { return errors.New("group attributes not set, group prefixes are ineffective") } } + a.setActiveDialURLs(a.DialURLs) return nil } +func (a *LDAPAuthenticator) setActiveDialURLs(urls []string) { + if len(a.DialURLs) == 1 { + return + } + a.startMonitorTicker(2 * time.Minute) + + a.mu.Lock() + defer a.mu.Unlock() + + a.activeURLs = nil + a.activeURLs = append(a.activeURLs, urls...) +} + +func (a *LDAPAuthenticator) addActiveDialURL(val string) { + if len(a.DialURLs) == 1 { + return + } + a.mu.Lock() + defer a.mu.Unlock() + + if !contains(a.activeURLs, val) { + a.activeURLs = append(a.activeURLs, val) + logger.AppLogger.Info("ldap connection restored", "dial URL", val, + "number of active dial URLs", len(a.activeURLs)) + } +} + +func (a *LDAPAuthenticator) removeActiveDialURL(val string, err error) { + if len(a.DialURLs) == 1 { + return + } + a.mu.Lock() + defer a.mu.Unlock() + + var urls []string + for _, u := range a.activeURLs { + if u != val { + urls = append(urls, u) + } + } + a.activeURLs = urls + logger.AppLogger.Error("ldap connection error", "dial URL", val, "error", err, + "number of active dial URLs", len(a.activeURLs)) +} + +func (a *LDAPAuthenticator) getDialURLs() []string { + if len(a.DialURLs) == 1 { + return a.DialURLs + } + a.mu.RLock() + defer a.mu.RUnlock() + + if len(a.activeURLs) == 0 { + logger.AppLogger.Warn("no active dial URL, trying all the defined URLs") + return a.DialURLs + } + + urls := make([]string, len(a.activeURLs)) + copy(urls, a.activeURLs) + + rand.Shuffle(len(urls), func(i, j int) { + urls[i], urls[j] = urls[j], urls[i] + }) + + return urls +} + +func (a *LDAPAuthenticator) isDialURLActive(val string) bool { + a.mu.RLock() + defer a.mu.RUnlock() + + return contains(a.activeURLs, val) +} + +func (a *LDAPAuthenticator) monitorDialURLs() { + for _, u := range a.DialURLs { + if !a.isDialURLActive(u) { + conn, err := a.getLDAPConnection(u) + if err == nil { + conn.Close() + a.addActiveDialURL(u) + } + } + } +} + func (a *LDAPAuthenticator) CheckUserAndPass(username, password, _, _ string, userAsJSON []byte) ([]byte, error) { if password == "" { return nil, errInvalidCredentials @@ -322,12 +422,27 @@ func (a *LDAPAuthenticator) isUserToUpdate(u *sdk.User, groups []sdk.GroupMappin return false } -func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) { +func (a *LDAPAuthenticator) connect() (conn *ldap.Conn, err error) { + for _, url := range a.getDialURLs() { + conn, err = a.getLDAPConnection(url) + if err == nil { + a.addActiveDialURL(url) + } else { + a.removeActiveDialURL(url, err) + } + if !a.isRetryableError(err) { + return + } + } + return +} + +func (a *LDAPAuthenticator) getLDAPConnection(dialURL string) (*ldap.Conn, error) { opts := []ldap.DialOpt{ ldap.DialWithDialer(&net.Dialer{Timeout: 15 * time.Second}), ldap.DialWithTLSConfig(a.tlsConfig), } - l, err := ldap.DialURL(a.DialURL, opts...) + l, err := ldap.DialURL(dialURL, opts...) if err != nil { return nil, err } @@ -339,3 +454,45 @@ func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) { } return l, err } + +func (*LDAPAuthenticator) isRetryableError(err error) bool { + if err == nil { + return false + } + var ldapErr *ldap.Error + if errors.As(err, &ldapErr) { + return ldapErr.ResultCode == ldap.ErrorNetwork + } + return false +} + +func (a *LDAPAuthenticator) stopMonitorTicker() { + if a.monitorTicker != nil { + a.monitorTicker.Stop() + a.cleanupDone <- true + a.monitorTicker = nil + } +} + +func (a *LDAPAuthenticator) startMonitorTicker(interval time.Duration) { + a.stopMonitorTicker() + a.monitorTicker = time.NewTicker(interval) + a.cleanupDone = make(chan bool) + + go func() { + logger.AppLogger.Info("start monitor task for dial URLs", "dial URLs", len(a.DialURLs)) + for { + select { + case <-a.cleanupDone: + logger.AppLogger.Info("monitor task for dial URLs ended") + return + case <-a.monitorTicker.C: + a.monitorDialURLs() + } + } + }() +} + +func (a *LDAPAuthenticator) Cleanup() { + a.stopMonitorTicker() +} diff --git a/authenticator/ldap_test.go b/authenticator/ldap_test.go index 2f6372e..9aec9c2 100644 --- a/authenticator/ldap_test.go +++ b/authenticator/ldap_test.go @@ -29,8 +29,6 @@ import ( ) const ( - ldapURL = "ldap://localhost:3893" - ldapsURL = "ldaps://localhost:3894" baseDN = "dc=glauth,dc=com" username = "cn=serviceuser,dc=glauth,dc=com" password = "mysecret" @@ -72,6 +70,12 @@ EsdgvPZR2e5IkA== -----END CERTIFICATE-----` ) +var ( + ldapURL = []string{"ldap://localhost:3893"} + ldapsURL = []string{"ldaps://localhost:3894"} + multipleLDAPURLs = []string{"ldap://localhost:3893", "ldap://localhost:3895"} +) + func TestLDAPAuthenticator(t *testing.T) { baseDir := filepath.Clean(os.TempDir()) auth, err := NewAuthenticator(ldapURL, baseDN, username, password, 0, false, baseDir, 2, searchQuery, @@ -189,7 +193,7 @@ func TestLDAPS(t *testing.T) { } func TestLDAPConnectionErrors(t *testing.T) { - auth, err := NewAuthenticator("ldap://localhost:3892", baseDN, username, password, 0, true, "", 0, searchQuery, + auth, err := NewAuthenticator([]string{"ldap://localhost:3892"}, baseDN, username, password, 0, true, "", 0, searchQuery, []string{groupAttribute}, nil, primaryGroupPrefix, secondaryGroupPrefix, membershipGroupPrefix, false) require.NoError(t, err) _, err = auth.CheckUserAndPass(user1, password, "", "", nil) @@ -213,11 +217,14 @@ func TestStartTLS(t *testing.T) { } func TestValidation(t *testing.T) { - _, err := NewAuthenticator("", "", "", "", 0, false, "", 0, "", nil, nil, "", "", "", false) + _, err := NewAuthenticator(nil, "", "", "", 0, false, "", 0, "", nil, nil, "", "", "", false) + require.Error(t, err) + assert.Contains(t, err.Error(), "dial URL is required") + _, err = NewAuthenticator([]string{"", ""}, "", "", "", 0, false, "", 0, "", nil, nil, "", "", "", false) require.Error(t, err) assert.Contains(t, err.Error(), "dial URL is required") a := LDAPAuthenticator{ - DialURL: ldapURL, + DialURLs: ldapURL, } err = a.validate() require.Error(t, err) @@ -258,6 +265,13 @@ func TestValidation(t *testing.T) { a.PrimaryGroupPrefix = "sftpgo_primary" err = a.validate() require.NoError(t, err) + a.DialURLs = []string{"ldap://1.2.3.4:389", "ldap://1.2.3.4:389", "ldap://1.2.3.5:389"} + err = a.validate() + require.NoError(t, err) + assert.Len(t, a.DialURLs, 2) + assert.Contains(t, a.DialURLs, "ldap://1.2.3.4:389") + assert.Contains(t, a.DialURLs, "ldap://1.2.3.5:389") + a.Cleanup() } func TestUnsupportedAuthMethods(t *testing.T) { @@ -349,3 +363,42 @@ func TestLoadCACerts(t *testing.T) { err = os.Remove(caCrtPath) require.NoError(t, err) } + +func TestLDAPMonitor(t *testing.T) { + auth, err := NewAuthenticator(multipleLDAPURLs, baseDN, username, password, 0, false, "", 2, searchQuery, + []string{groupAttribute}, nil, primaryGroupPrefix, secondaryGroupPrefix, membershipGroupPrefix, true) + require.NoError(t, err) + defer auth.Cleanup() + + assert.Len(t, auth.getDialURLs(), 2) + auth.removeActiveDialURL(multipleLDAPURLs[0], nil) + auth.removeActiveDialURL(multipleLDAPURLs[1], nil) + + auth.startMonitorTicker(100 * time.Millisecond) + assert.Eventually(t, func() bool { + return len(auth.getDialURLs()) == 1 + }, time.Second, 250*time.Millisecond) + + auth.removeActiveDialURL(multipleLDAPURLs[0], nil) + auth.removeActiveDialURL(multipleLDAPURLs[1], nil) + // no active URL, all defined urls will be returned + assert.Len(t, auth.getDialURLs(), 2) +} + +func TestRetryableErrors(t *testing.T) { + a := LDAPAuthenticator{} + require.False(t, a.isRetryableError(nil)) + + err := &ldap.Error{ + Err: errNotImplemented, + ResultCode: ldap.ErrorNetwork, + } + require.True(t, a.isRetryableError(err)) + + err = &ldap.Error{ + Err: errNotImplemented, + ResultCode: ldap.ErrorUnexpectedMessage, + } + require.False(t, a.isRetryableError(err)) + require.False(t, a.isRetryableError(fs.ErrPermission)) +} diff --git a/cmd/cmd.go b/cmd/cmd.go index 723f321..09f463c 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -28,7 +28,7 @@ import ( ) const ( - version = "1.0.6" + version = "1.0.7-dev" envPrefix = "SFTPGO_PLUGIN_AUTH_" ) @@ -50,7 +50,7 @@ func init() { } var ( - ldapURL string + ldapURL cli.StringSlice ldapBaseDN string ldapUsername string ldapPassword string @@ -75,9 +75,9 @@ var ( Name: "serve", Usage: "Launch the SFTPGo plugin, it must be called from an SFTPGo instance", Flags: []cli.Flag{ - &cli.StringFlag{ + &cli.StringSliceFlag{ Name: "ldap-url", - Usage: "LDAP url, e.g ldap://192.168.1.5:389 or ldaps://192.168.1.5:636", + Usage: "LDAP url, e.g ldap://192.168.1.5:389 or ldaps://192.168.1.5:636. By specifying multiple URLs you will achieve load balancing and high availability", Destination: &ldapURL, EnvVars: []string{envPrefix + "LDAP_URL"}, }, @@ -173,7 +173,7 @@ var ( }, }, Action: func(ctx *cli.Context) error { - a, err := authenticator.NewAuthenticator(ldapURL, ldapBaseDN, ldapUsername, ldapPassword, startTLS, + a, err := authenticator.NewAuthenticator(ldapURL.Value(), ldapBaseDN, ldapUsername, ldapPassword, startTLS, skipTLSVerify == 1, usersBaseDir, cacheTime, ldapSearchQuery, ldapGroupAttributes.Value(), caCertificates.Value(), primaryGroupPrefix, secondaryGroupPrefix, membershipGroupPrefix, requireGroupMembership) @@ -189,6 +189,7 @@ var ( GRPCServer: plugin.DefaultGRPCServer, }) + a.Cleanup() return errors.New("the plugin exited unexpectedly") }, },