Skip to content

Commit

Permalink
Add health-check support to loadbalancer
Browse files Browse the repository at this point in the history
* Adds support for health-checking loadbalancer servers. If a
  health-check fails when dialing, all existing connections to the
  server will be closed.
* Wires up a remotedialer tunnel connectivity check as the health check
  for supervisor/apiserver connections.
* Wires up a simple ping request to the supervisor port as the health
  check for etcd connections.

Signed-off-by: Brad Davidson <[email protected]>
  • Loading branch information
brandond committed Mar 27, 2024
1 parent edb0440 commit c51d7bf
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 62 deletions.
2 changes: 1 addition & 1 deletion pkg/agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N
// If the supervisor and externally-facing apiserver are not on the same port, tell the proxy where to find the apiserver.
if controlConfig.SupervisorPort != controlConfig.HTTPSPort {
isIPv6 := utilsnet.IsIPv6(net.ParseIP([]string{envInfo.NodeIP.String()}[0]))
if err := proxy.SetAPIServerPort(ctx, controlConfig.HTTPSPort, isIPv6); err != nil {
if err := proxy.SetAPIServerPort(controlConfig.HTTPSPort, isIPv6); err != nil {
return nil, errors.Wrapf(err, "failed to setup access to API Server port %d on at %s", controlConfig.HTTPSPort, proxy.SupervisorURL())
}
}
Expand Down
35 changes: 13 additions & 22 deletions pkg/agent/loadbalancer/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ import (

// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
address string
healthCheck func() bool
connections map[net.Conn]struct{}
}

Expand All @@ -31,7 +34,9 @@ type serverConn struct {
// actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails.
type LoadBalancer struct {
mutex sync.Mutex
// This mutex protects access to servers map and randomServers list.
// All direct access to the servers map/list should be protected by it.
mutex sync.RWMutex
proxy *tcpproxy.Proxy

serviceName string
Expand Down Expand Up @@ -123,26 +128,9 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
}
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress)

return lb, nil
}

func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.mutex.Lock()
defer lb.mutex.Unlock()

_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress)
}
// if the new default server doesn't have an entry in the map, add one
if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{connections: make(map[net.Conn]struct{})}
}
go lb.runHealthChecks(ctx)

lb.defaultServerAddress = serverAddress
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
return lb, nil
}

func (lb *LoadBalancer) Update(serverAddresses []string) {
Expand All @@ -166,15 +154,18 @@ func (lb *LoadBalancer) LoadBalancerServerURL() string {
return lb.localServerURL
}

func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()

startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress

server := lb.servers[targetServer]
if server == nil || targetServer == "" {
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
} else {
} else if server.healthCheck() {
conn, err := server.dialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
Expand Down
74 changes: 67 additions & 7 deletions pkg/agent/loadbalancer/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"os"
"strconv"
"time"

"github.com/k3s-io/k3s/pkg/version"
http_dialer "github.com/mwitkow/go-http-dialer"
Expand All @@ -17,6 +18,7 @@ import (

"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
)

var defaultDialer proxy.Dialer = &net.Dialer{}
Expand Down Expand Up @@ -73,7 +75,11 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {

for addedServer := range newAddresses.Difference(curAddresses) {
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
lb.servers[addedServer] = &server{connections: make(map[net.Conn]struct{})}
lb.servers[addedServer] = &server{
address: addedServer,
connections: make(map[net.Conn]struct{}),
healthCheck: func() bool { return true },
}
}

for removedServer := range curAddresses.Difference(newAddresses) {
Expand Down Expand Up @@ -106,8 +112,8 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
}

func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
lb.mutex.RLock()
defer lb.mutex.RUnlock()

if len(lb.randomServers) == 0 {
return "", errors.New("No servers in load balancer proxy list")
Expand Down Expand Up @@ -162,10 +168,12 @@ func (s *server) closeAll() {
s.mutex.Lock()
defer s.mutex.Unlock()

logrus.Debugf("Closing %d connections to load balancer server", len(s.connections))
for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close()
if l := len(s.connections); l > 0 {
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address)
for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close()
}
}
}

Expand All @@ -178,3 +186,55 @@ func (sc *serverConn) Close() error {
delete(sc.server.connections, sc)
return sc.Conn.Close()
}

// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.mutex.Lock()
defer lb.mutex.Unlock()

_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress)
}
// if the new default server doesn't have an entry in the map, add one
if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{
address: serverAddress,
healthCheck: func() bool { return true },
connections: make(map[net.Conn]struct{}),
}
}

lb.defaultServerAddress = serverAddress
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
}

// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
lb.mutex.Lock()
defer lb.mutex.Unlock()

if server := lb.servers[address]; server != nil {
logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
server.healthCheck = healthCheck
} else {
logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
}
}

// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
// connections closed, to force clients to switch over to a healthy server.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
wait.Until(func() {
lb.mutex.RLock()
defer lb.mutex.RUnlock()
for _, server := range lb.servers {
if !server.healthCheck() {
defer server.closeAll()
}
}
}, time.Second, ctx.Done())
logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
}
26 changes: 22 additions & 4 deletions pkg/agent/proxy/apiproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"context"
"net"
sysnet "net"
"net/url"
"strconv"
Expand All @@ -14,13 +15,14 @@ import (

type Proxy interface {
Update(addresses []string)
SetAPIServerPort(ctx context.Context, port int, isIPv6 bool) error
SetAPIServerPort(port int, isIPv6 bool) error
SetSupervisorDefault(address string)
IsSupervisorLBEnabled() bool
SupervisorURL() string
SupervisorAddresses() []string
APIServerURL() string
IsAPIServerLBEnabled() bool
SetHealthCheck(address string, healthCheck func() bool)
}

// NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
Expand All @@ -38,6 +40,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor
supervisorURL: supervisorURL,
apiServerURL: supervisorURL,
lbServerPort: lbServerPort,
context: ctx,
}

if lbEnabled {
Expand Down Expand Up @@ -70,6 +73,7 @@ type proxy struct {
apiServerEnabled bool

apiServerURL string
apiServerPort string
supervisorURL string
supervisorPort string
initialSupervisorURL string
Expand All @@ -78,6 +82,7 @@ type proxy struct {

apiServerLB *loadbalancer.LoadBalancer
supervisorLB *loadbalancer.LoadBalancer
context context.Context
}

func (p *proxy) Update(addresses []string) {
Expand All @@ -96,6 +101,18 @@ func (p *proxy) Update(addresses []string) {
p.supervisorAddresses = supervisorAddresses
}

func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) {
if p.supervisorLB != nil {
p.supervisorLB.SetHealthCheck(address, healthCheck)
}

if p.apiServerLB != nil {
host, _, _ := net.SplitHostPort(address)
address = net.JoinHostPort(host, p.apiServerPort)
p.apiServerLB.SetHealthCheck(address, healthCheck)
}
}

func (p *proxy) setSupervisorPort(addresses []string) []string {
var newAddresses []string
for _, address := range addresses {
Expand All @@ -114,12 +131,13 @@ func (p *proxy) setSupervisorPort(addresses []string) []string {
// load-balancing is enabled, another load-balancer is started on a port one below the supervisor
// load-balancer, and the address of this load-balancer is returned instead of the actual apiserver
// addresses.
func (p *proxy) SetAPIServerPort(ctx context.Context, port int, isIPv6 bool) error {
func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error {
u, err := url.Parse(p.initialSupervisorURL)
if err != nil {
return errors.Wrapf(err, "failed to parse server URL %s", p.initialSupervisorURL)
}
u.Host = sysnet.JoinHostPort(u.Hostname(), strconv.Itoa(port))
p.apiServerPort = strconv.Itoa(port)
u.Host = sysnet.JoinHostPort(u.Hostname(), p.apiServerPort)

p.apiServerURL = u.String()
p.apiServerEnabled = true
Expand All @@ -129,7 +147,7 @@ func (p *proxy) SetAPIServerPort(ctx context.Context, port int, isIPv6 bool) err
if lbServerPort != 0 {
lbServerPort = lbServerPort - 1
}
lb, err := loadbalancer.New(ctx, p.dataDir, loadbalancer.APIServerServiceName, p.apiServerURL, lbServerPort, isIPv6)
lb, err := loadbalancer.New(p.context, p.dataDir, loadbalancer.APIServerServiceName, p.apiServerURL, lbServerPort, isIPv6)
if err != nil {
return err
}
Expand Down
56 changes: 43 additions & 13 deletions pkg/agent/tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tunnel
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -289,7 +290,9 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
disconnect := map[string]context.CancelFunc{}
for _, address := range proxy.SupervisorAddresses() {
if _, ok := disconnect[address]; !ok {
disconnect[address] = a.connect(ctx, wg, address, tlsConfig)
conn := a.connect(ctx, wg, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
}
}

Expand Down Expand Up @@ -361,7 +364,9 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
for _, address := range proxy.SupervisorAddresses() {
validEndpoint[address] = true
if _, ok := disconnect[address]; !ok {
disconnect[address] = a.connect(ctx, nil, address, tlsConfig)
conn := a.connect(ctx, nil, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
}
}

Expand Down Expand Up @@ -403,32 +408,54 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo
return false
}

type agentConnection struct {
cancel context.CancelFunc
connected func() bool
}

// connect initiates a connection to the remotedialer server. Incoming dial requests from
// the server will be checked by the authorizer function prior to being fulfilled.
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) context.CancelFunc {
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
ws := &websocket.Dialer{
TLSClientConfig: tlsConfig,
}

// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
connected := true

once := sync.Once{}
if waitGroup != nil {
waitGroup.Add(1)
}

ctx, cancel := context.WithCancel(rootCtx)
auth := func(proto, address string) bool {
return a.authorized(rootCtx, proto, address)
}

onConnect := func(_ context.Context, _ *remotedialer.Session) error {
connected = true
logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return nil
}

// Start remotedialer connect loop in a goroutine to ensure a connection to the target server
go func() {
for {
remotedialer.ClientConnect(ctx, wsURL, nil, ws, func(proto, address string) bool {
return a.authorized(rootCtx, proto, address)
}, func(_ context.Context, _ *remotedialer.Session) error {
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return nil
})

// ConnectToProxy blocks until error or context cancellation
err := remotedialer.ConnectToProxy(ctx, wsURL, nil, auth, ws, onConnect)
connected = false
if err != nil && !errors.Is(err, context.Canceled) {
logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconecting...")
// wait between reconnection attempts to avoid hammering the server
time.Sleep(endpointDebounceDelay)
}
// If the context has been cancelled, exit the goroutine instead of retrying
if ctx.Err() != nil {
if waitGroup != nil {
once.Do(waitGroup.Done)
Expand All @@ -438,7 +465,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
}
}()

return cancel
return agentConnection{
cancel: cancel,
connected: func() bool { return connected },
}
}

// isKubeletPort returns true if the connection is to a reserved TCP port on a loopback address.
Expand Down
Loading

0 comments on commit c51d7bf

Please sign in to comment.