diff --git a/internal/dcs/zk.go b/internal/dcs/zk.go index 0e9cea6..c19fb28 100644 --- a/internal/dcs/zk.go +++ b/internal/dcs/zk.go @@ -74,7 +74,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *slog.Log var operation func() error - hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, proxyLogger) + hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, !config.UseSSL, proxyLogger) if config.UseSSL { if config.CACert == "" || config.KeyFile == "" || config.CertFile == "" { @@ -85,7 +85,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *slog.Log return nil, err } baseDialer := net.Dialer{Timeout: config.SessionTimeout} - dialer, err := GetTLSDialer(config.Hosts, &baseDialer, tlsConfig) + dialer, err := GetTLSDialer(&baseDialer, tlsConfig) if err != nil { return nil, err } diff --git a/internal/dcs/zk_host_provider.go b/internal/dcs/zk_host_provider.go index 577c892..f31f82b 100644 --- a/internal/dcs/zk_host_provider.go +++ b/internal/dcs/zk_host_provider.go @@ -18,6 +18,7 @@ type zkhost struct { type RandomHostProvider struct { ctx context.Context hosts sync.Map + useAddrs bool hostsKeys []string tried map[string]struct{} logger *slog.Logger @@ -27,7 +28,7 @@ type RandomHostProvider struct { resolver *net.Resolver } -func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, logger *slog.Logger) *RandomHostProvider { +func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, useAddrs bool, logger *slog.Logger) *RandomHostProvider { return &RandomHostProvider{ ctx: ctx, lookupTTL: config.LookupTTL, @@ -37,6 +38,7 @@ func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig tried: make(map[string]struct{}), hosts: sync.Map{}, resolver: &net.Resolver{}, + useAddrs: useAddrs, } } @@ -146,7 +148,11 @@ func (rhp *RandomHostProvider) Next() (server string, retryStart bool) { zhost := host.(zkhost) if len(zhost.resolved) > 0 { - ret = zhost.resolved[rand.Intn(len(zhost.resolved))] + if rhp.useAddrs { + ret = zhost.resolved[rand.Intn(len(zhost.resolved))] + } else { + ret = selected + } } } diff --git a/internal/dcs/zk_tls.go b/internal/dcs/zk_tls.go index b3d00c1..9fdadea 100644 --- a/internal/dcs/zk_tls.go +++ b/internal/dcs/zk_tls.go @@ -3,7 +3,6 @@ package dcs import ( "crypto/tls" "crypto/x509" - "errors" "net" "os" "time" @@ -11,24 +10,6 @@ import ( "github.com/go-zookeeper/zk" ) -// TODO: if pr https://github.com/go-zookeeper/zk/pull/106 will be merged -// remove this file and use same functions from go-zookeeper/zk -func addrsByHostname(server string) ([]string, error) { - res := []string{} - host, port, err := net.SplitHostPort(server) - if err != nil { - return nil, err - } - addrs, err := net.LookupHost(host) - if err != nil { - return nil, err - } - for _, addr := range addrs { - res = append(res, net.JoinHostPort(addr, port)) - } - return res, nil -} - func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error) { rootCABytes, err := os.ReadFile(rootCAFile) if err != nil { @@ -52,29 +33,8 @@ func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error) }, nil } -func GetTLSDialer(servers []string, dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) { - if len(servers) == 0 { - return nil, errors.New("zk: server list must not be empty") - } - srvs := zk.FormatServers(servers) - - addrToHostname := map[string]string{} - - for _, server := range srvs { - ips, err := addrsByHostname(server) - if err != nil { - return nil, err - } - for _, ip := range ips { - addrToHostname[ip] = server - } - } - +func GetTLSDialer(dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) { return func(network, address string, _ time.Duration) (net.Conn, error) { - server, ok := addrToHostname[address] - if !ok { - server = address - } - return tls.DialWithDialer(dialer, network, server, tlsConfig) + return tls.DialWithDialer(dialer, network, address, tlsConfig) }, nil }