Skip to content

Commit

Permalink
Fix zookeeper ip address change with tls
Browse files Browse the repository at this point in the history
  • Loading branch information
secwall committed Oct 21, 2024
1 parent a5b83fe commit 2ca7562
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 46 deletions.
4 changes: 2 additions & 2 deletions internal/dcs/zk.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *slog.Log

var operation func() error

hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, proxyLogger)
hostProvider := NewRandomHostProvider(ctx, &config.RandomHostProvider, !config.UseSSL, proxyLogger)

if config.UseSSL {
if config.CACert == "" || config.KeyFile == "" || config.CertFile == "" {
Expand All @@ -85,7 +85,7 @@ func NewZookeeper(ctx context.Context, config *ZookeeperConfig, logger *slog.Log
return nil, err
}
baseDialer := net.Dialer{Timeout: config.SessionTimeout}
dialer, err := GetTLSDialer(config.Hosts, &baseDialer, tlsConfig)
dialer, err := GetTLSDialer(&baseDialer, tlsConfig)
if err != nil {
return nil, err
}
Expand Down
10 changes: 8 additions & 2 deletions internal/dcs/zk_host_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type zkhost struct {
type RandomHostProvider struct {
ctx context.Context
hosts sync.Map
useAddrs bool
hostsKeys []string
tried map[string]struct{}
logger *slog.Logger
Expand All @@ -27,7 +28,7 @@ type RandomHostProvider struct {
resolver *net.Resolver
}

func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, logger *slog.Logger) *RandomHostProvider {
func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig, useAddrs bool, logger *slog.Logger) *RandomHostProvider {
return &RandomHostProvider{
ctx: ctx,
lookupTTL: config.LookupTTL,
Expand All @@ -37,6 +38,7 @@ func NewRandomHostProvider(ctx context.Context, config *RandomHostProviderConfig
tried: make(map[string]struct{}),
hosts: sync.Map{},
resolver: &net.Resolver{},
useAddrs: useAddrs,
}
}

Expand Down Expand Up @@ -146,7 +148,11 @@ func (rhp *RandomHostProvider) Next() (server string, retryStart bool) {
zhost := host.(zkhost)

if len(zhost.resolved) > 0 {
ret = zhost.resolved[rand.Intn(len(zhost.resolved))]
if rhp.useAddrs {
ret = zhost.resolved[rand.Intn(len(zhost.resolved))]
} else {
ret = selected
}
}
}

Expand Down
44 changes: 2 additions & 42 deletions internal/dcs/zk_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,13 @@ package dcs
import (
"crypto/tls"
"crypto/x509"
"errors"
"net"
"os"
"time"

"github.com/go-zookeeper/zk"
)

// TODO: if pr https://github.com/go-zookeeper/zk/pull/106 will be merged
// remove this file and use same functions from go-zookeeper/zk
func addrsByHostname(server string) ([]string, error) {
res := []string{}
host, port, err := net.SplitHostPort(server)
if err != nil {
return nil, err
}
addrs, err := net.LookupHost(host)
if err != nil {
return nil, err
}
for _, addr := range addrs {
res = append(res, net.JoinHostPort(addr, port))
}
return res, nil
}

func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error) {
rootCABytes, err := os.ReadFile(rootCAFile)
if err != nil {
Expand All @@ -52,29 +33,8 @@ func CreateTLSConfig(rootCAFile, certFile, keyFile string) (*tls.Config, error)
}, nil
}

func GetTLSDialer(servers []string, dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) {
if len(servers) == 0 {
return nil, errors.New("zk: server list must not be empty")
}
srvs := zk.FormatServers(servers)

addrToHostname := map[string]string{}

for _, server := range srvs {
ips, err := addrsByHostname(server)
if err != nil {
return nil, err
}
for _, ip := range ips {
addrToHostname[ip] = server
}
}

func GetTLSDialer(dialer *net.Dialer, tlsConfig *tls.Config) (zk.Dialer, error) {
return func(network, address string, _ time.Duration) (net.Conn, error) {
server, ok := addrToHostname[address]
if !ok {
server = address
}
return tls.DialWithDialer(dialer, network, server, tlsConfig)
return tls.DialWithDialer(dialer, network, address, tlsConfig)
}, nil
}

0 comments on commit 2ca7562

Please sign in to comment.