diff --git a/internal/app/app.go b/internal/app/app.go index d0c944f..2ed493e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -92,7 +92,7 @@ func NewApp(configFile, logLevel string) (*App, error) { func (app *App) connectDCS() error { var err error - app.dcs, err = dcs.NewZookeeper(&app.config.Zookeeper, app.logger) + app.dcs, err = dcs.NewZookeeper(app.ctx, &app.config.Zookeeper, app.logger) if err != nil { return fmt.Errorf("failed to connect to zkDCS: %s", err.Error()) } diff --git a/internal/dcs/config.go b/internal/dcs/config.go index 2f185c1..985d553 100644 --- a/internal/dcs/config.go +++ b/internal/dcs/config.go @@ -31,14 +31,16 @@ type ZookeeperConfig struct { } type RandomHostProviderConfig struct { - LookupTimeout time.Duration `config:"lookup_timeout" yaml:"lookup_timeout"` - LookupTTL time.Duration `config:"lookup_ttl" yaml:"lookup_ttl"` + LookupTimeout time.Duration `config:"lookup_timeout" yaml:"lookup_timeout"` + LookupTTL time.Duration `config:"lookup_ttl" yaml:"lookup_ttl"` + LookupTickInterval time.Duration `config:"lookup_tick_interval" yaml:"lookup_tick_interval"` } func DefaultRandomHostProviderConfig() RandomHostProviderConfig { return RandomHostProviderConfig{ - LookupTimeout: 3 * time.Second, - LookupTTL: 300 * time.Second, + LookupTimeout: 3 * time.Second, + LookupTTL: 300 * time.Second, + LookupTickInterval: 60 * time.Second, } } diff --git a/internal/dcs/zk.go b/internal/dcs/zk.go index 00c9419..5e34d96 100644 --- a/internal/dcs/zk.go +++ b/internal/dcs/zk.go @@ -1,6 +1,7 @@ package dcs import ( + "context" "encoding/json" "fmt" "log/slog" @@ -50,7 +51,7 @@ func retry(config *ZookeeperConfig, operation func() error) error { } // NewZookeeper returns Zookeeper based DCS storage -func NewZookeeper(config *ZookeeperConfig, logger *slog.Logger) (DCS, error) { +func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *slog.Logger) (DCS, error) { if len(config.Hosts) == 0 { return nil, fmt.Errorf("zookeeper not configured, fill zookeeper/hosts in config") } @@ -72,7 +73,7 @@ func NewZookeeper(config *ZookeeperConfig, logger *slog.Logger) (DCS, error) { var operation func() error - hostProvider := NewRandomHostProvider(&config.RandomHostProvider, logger) + hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, logger) if config.UseSSL { if config.CACert == "" || config.KeyFile == "" || config.CertFile == "" { diff --git a/internal/dcs/zk_host_provider.go b/internal/dcs/zk_host_provider.go index 93d1297..577c892 100644 --- a/internal/dcs/zk_host_provider.go +++ b/internal/dcs/zk_host_provider.go @@ -10,119 +10,150 @@ import ( "time" ) +type zkhost struct { + resolved []string + lastLookup time.Time +} + type RandomHostProvider struct { - lock sync.Mutex - servers []string - resolved []string - tried map[string]struct{} - logger *slog.Logger - lastLookup time.Time - lookupTTL time.Duration - lookupTimeout time.Duration - resolver *net.Resolver + ctx context.Context + hosts sync.Map + hostsKeys []string + tried map[string]struct{} + logger *slog.Logger + lookupTTL time.Duration + lookupTimeout time.Duration + lookupTickInterval time.Duration + resolver *net.Resolver } -func NewRandomHostProvider(config *RandomHostProviderConfig, logger *slog.Logger) *RandomHostProvider { +func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, logger *slog.Logger) *RandomHostProvider { return &RandomHostProvider{ - lookupTTL: config.LookupTTL, - lookupTimeout: config.LookupTimeout, - logger: logger, - tried: make(map[string]struct{}), - resolver: &net.Resolver{}, + ctx: ctx, + lookupTTL: config.LookupTTL, + lookupTimeout: config.LookupTimeout, + lookupTickInterval: config.LookupTickInterval, + logger: logger, + tried: make(map[string]struct{}), + hosts: sync.Map{}, + resolver: &net.Resolver{}, } } func (rhp *RandomHostProvider) Init(servers []string) error { - rhp.lock.Lock() - defer rhp.lock.Unlock() + numResolved := 0 - rhp.servers = servers - - err := rhp.resolveHosts() + for _, host := range servers { + resolved, err := rhp.resolveHost(host) + if err != nil { + rhp.logger.Error(fmt.Sprintf("host definition %s is invalid", host), "error", err) + continue + } + numResolved += len(resolved) + rhp.hosts.Store(host, zkhost{ + resolved: resolved, + lastLookup: time.Now(), + }) + rhp.hostsKeys = append(rhp.hostsKeys, host) + } - if err != nil { - return fmt.Errorf("failed to init zk host provider %v", err) + if numResolved == 0 { + return fmt.Errorf("unable to resolve any host from %v", servers) } + go rhp.resolveHosts() + return nil } -func (rhp *RandomHostProvider) resolveHosts() error { - resolved := []string{} - for _, server := range rhp.servers { - host, port, err := net.SplitHostPort(server) - if err != nil { - return err - } - ctx, cancel := context.WithTimeout(context.Background(), rhp.lookupTimeout) - defer cancel() - addrs, err := rhp.resolver.LookupHost(ctx, host) - if err != nil { - rhp.logger.Error(fmt.Sprintf("unable to resolve %s", host), "error", err) - } - for _, addr := range addrs { - resolved = append(resolved, net.JoinHostPort(addr, port)) +func (rhp *RandomHostProvider) resolveHosts() { + ticker := time.NewTicker(rhp.lookupTickInterval) + for { + select { + case <-ticker.C: + for _, pair := range rhp.hostsKeys { + host, _ := rhp.hosts.Load(pair) + zhost := host.(zkhost) + + if len(zhost.resolved) == 0 || time.Since(zhost.lastLookup) > rhp.lookupTTL { + resolved, err := rhp.resolveHost(pair) + if err != nil || len(resolved) == 0 { + rhp.logger.Error(fmt.Sprintf("background resolve for %s failed", pair), "error", err) + continue + } + rhp.hosts.Store(pair, zkhost{ + resolved: resolved, + lastLookup: time.Now(), + }) + } + } + case <-rhp.ctx.Done(): + return } } +} - if len(resolved) == 0 { - return fmt.Errorf("no hosts resolved for %q", rhp.servers) +func (rhp *RandomHostProvider) resolveHost(pair string) ([]string, error) { + var res []string + host, port, err := net.SplitHostPort(pair) + if err != nil { + return res, err + } + ctx, cancel := context.WithTimeout(rhp.ctx, rhp.lookupTimeout) + defer cancel() + addrs, err := rhp.resolver.LookupHost(ctx, host) + if err != nil { + rhp.logger.Error(fmt.Sprintf("unable to resolve %s", host), "error", err) + } + for _, addr := range addrs { + res = append(res, net.JoinHostPort(addr, port)) } - rhp.lastLookup = time.Now() - rhp.resolved = resolved - - rand.Shuffle(len(rhp.resolved), func(i, j int) { rhp.resolved[i], rhp.resolved[j] = rhp.resolved[j], rhp.resolved[i] }) - - return nil + return res, nil } func (rhp *RandomHostProvider) Len() int { - rhp.lock.Lock() - defer rhp.lock.Unlock() - return len(rhp.resolved) + return len(rhp.hostsKeys) } func (rhp *RandomHostProvider) Next() (server string, retryStart bool) { - rhp.lock.Lock() - defer rhp.lock.Unlock() - lastTime := time.Since(rhp.lastLookup) needRetry := false - if lastTime > rhp.lookupTTL { - err := rhp.resolveHosts() - if err != nil { - rhp.logger.Error("resolve zk hosts failed", "error", err) - } - } - notTried := []string{} + var ret string - for _, addr := range rhp.resolved { - if _, ok := rhp.tried[addr]; !ok { - notTried = append(notTried, addr) + for len(ret) == 0 { + notTried := []string{} + + for _, host := range rhp.hostsKeys { + if _, ok := rhp.tried[host]; !ok { + notTried = append(notTried, host) + } } - } - var selected string + var selected string + if len(notTried) == 0 { + needRetry = true + for k := range rhp.tried { + delete(rhp.tried, k) + } + selected = rhp.hostsKeys[rand.Intn(len(rhp.hostsKeys))] + } else { + selected = notTried[rand.Intn(len(notTried))] + } + rhp.tried[selected] = struct{}{} + + host, _ := rhp.hosts.Load(selected) + zhost := host.(zkhost) - if len(notTried) == 0 { - needRetry = true - for k := range rhp.tried { - delete(rhp.tried, k) + if len(zhost.resolved) > 0 { + ret = zhost.resolved[rand.Intn(len(zhost.resolved))] } - selected = rhp.resolved[rand.Intn(len(rhp.resolved))] - } else { - selected = notTried[rand.Intn(len(notTried))] } - rhp.tried[selected] = struct{}{} - - return selected, needRetry + return ret, needRetry } func (rhp *RandomHostProvider) Connected() { - rhp.lock.Lock() - defer rhp.lock.Unlock() for k := range rhp.tried { delete(rhp.tried, k) }