Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add health-check support to loadbalancer #9757

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N
// If the supervisor and externally-facing apiserver are not on the same port, tell the proxy where to find the apiserver.
if controlConfig.SupervisorPort != controlConfig.HTTPSPort {
isIPv6 := utilsnet.IsIPv6(net.ParseIP([]string{envInfo.NodeIP.String()}[0]))
if err := proxy.SetAPIServerPort(ctx, controlConfig.HTTPSPort, isIPv6); err != nil {
if err := proxy.SetAPIServerPort(controlConfig.HTTPSPort, isIPv6); err != nil {
return nil, errors.Wrapf(err, "failed to setup access to API Server port %d on at %s", controlConfig.HTTPSPort, proxy.SupervisorURL())
}
}
Expand Down
35 changes: 13 additions & 22 deletions pkg/agent/loadbalancer/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ import (

// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
	// This mutex protects access to the connections map. All direct access to the map should be protected by it.
	mutex sync.Mutex
	// address is the host:port of the backend server this entry tracks.
	address string
	// healthCheck reports whether the server is currently considered healthy.
	// It defaults to an always-true no-op until SetHealthCheck installs a real callback.
	healthCheck func() bool
	// connections holds every open connection to this server, so that they can
	// all be closed when the server is removed or fails its health check.
	connections map[net.Conn]struct{}
}

Expand All @@ -31,7 +34,9 @@ type serverConn struct {
// actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails.
type LoadBalancer struct {
mutex sync.Mutex
// This mutex protects access to servers map and randomServers list.
// All direct access to the servers map/list should be protected by it.
mutex sync.RWMutex
proxy *tcpproxy.Proxy

serviceName string
Expand Down Expand Up @@ -123,26 +128,9 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
}
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress)

return lb, nil
}

func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.mutex.Lock()
defer lb.mutex.Unlock()

_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress)
}
// if the new default server doesn't have an entry in the map, add one
if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{connections: make(map[net.Conn]struct{})}
}
go lb.runHealthChecks(ctx)

lb.defaultServerAddress = serverAddress
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
return lb, nil
}

func (lb *LoadBalancer) Update(serverAddresses []string) {
Expand All @@ -166,15 +154,18 @@ func (lb *LoadBalancer) LoadBalancerServerURL() string {
return lb.localServerURL
}

func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()

startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress

server := lb.servers[targetServer]
if server == nil || targetServer == "" {
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
} else {
} else if server.healthCheck() {
conn, err := server.dialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
Expand Down
74 changes: 67 additions & 7 deletions pkg/agent/loadbalancer/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"os"
"strconv"
"time"

"github.com/k3s-io/k3s/pkg/version"
http_dialer "github.com/mwitkow/go-http-dialer"
Expand All @@ -17,6 +18,7 @@ import (

"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
)

var defaultDialer proxy.Dialer = &net.Dialer{}
Expand Down Expand Up @@ -73,7 +75,11 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {

for addedServer := range newAddresses.Difference(curAddresses) {
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
lb.servers[addedServer] = &server{connections: make(map[net.Conn]struct{})}
lb.servers[addedServer] = &server{
address: addedServer,
connections: make(map[net.Conn]struct{}),
healthCheck: func() bool { return true },
}
}

for removedServer := range curAddresses.Difference(newAddresses) {
Expand Down Expand Up @@ -106,8 +112,8 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
}

func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
lb.mutex.RLock()
defer lb.mutex.RUnlock()

if len(lb.randomServers) == 0 {
return "", errors.New("No servers in load balancer proxy list")
Expand Down Expand Up @@ -162,10 +168,12 @@ func (s *server) closeAll() {
s.mutex.Lock()
defer s.mutex.Unlock()

logrus.Debugf("Closing %d connections to load balancer server", len(s.connections))
for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close()
if l := len(s.connections); l > 0 {
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address)
for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close()
}
}
}

Expand All @@ -178,3 +186,55 @@ func (sc *serverConn) Close() error {
delete(sc.server.connections, sc)
return sc.Conn.Close()
}

// SetDefault sets the selected address as the default / fallback address.
//
// The default server is kept in the server map even when it is absent from the
// current ServerAddresses list, so that the load balancer always has somewhere
// to fail over to.
func (lb *LoadBalancer) SetDefault(serverAddress string) {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	// Only the second return value is needed here: whether the old default is
	// also present in the regular server address list.
	_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
	// if the old default server is not currently in use, remove it from the server map
	if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
		// Deferred so the map delete happens first; since defers run LIFO, closeAll
		// still executes before the mutex is released by the earlier defer.
		defer server.closeAll()
		delete(lb.servers, lb.defaultServerAddress)
	}
	// if the new default server doesn't have an entry in the map, add one
	if _, ok := lb.servers[serverAddress]; !ok {
		lb.servers[serverAddress] = &server{
			address: serverAddress,
			// Start with an always-healthy check; a real callback may be
			// installed later via SetHealthCheck.
			healthCheck: func() bool { return true },
			connections: make(map[net.Conn]struct{}),
		}
	}

	lb.defaultServerAddress = serverAddress
	logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
}

// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
// If no server entry exists for the address, the callback is dropped and an error is logged.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	srv := lb.servers[address]
	if srv == nil {
		logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
		return
	}

	logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
	srv.healthCheck = healthCheck
}

// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
// connections closed, to force clients to switch over to a healthy server.
// Runs until ctx is cancelled; checks fire once per second via wait.Until.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
	wait.Until(func() {
		lb.mutex.RLock()
		defer lb.mutex.RUnlock()
		for _, server := range lb.servers {
			if !server.healthCheck() {
				// Deferred on purpose: the closeAll calls accumulate and run when
				// this tick's closure returns (before RUnlock, since defers run
				// LIFO), so the servers map is not mutated mid-iteration.
				// closeAll only takes the per-server mutex, so holding the read
				// lock here does not deadlock.
				defer server.closeAll()
			}
		}
	}, time.Second, ctx.Done())
	logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
}
26 changes: 22 additions & 4 deletions pkg/agent/proxy/apiproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"context"
"net"
sysnet "net"
"net/url"
"strconv"
Expand All @@ -14,13 +15,14 @@ import (

type Proxy interface {
Update(addresses []string)
SetAPIServerPort(ctx context.Context, port int, isIPv6 bool) error
SetAPIServerPort(port int, isIPv6 bool) error
SetSupervisorDefault(address string)
IsSupervisorLBEnabled() bool
SupervisorURL() string
SupervisorAddresses() []string
APIServerURL() string
IsAPIServerLBEnabled() bool
SetHealthCheck(address string, healthCheck func() bool)
}

// NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
Expand All @@ -38,6 +40,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor
supervisorURL: supervisorURL,
apiServerURL: supervisorURL,
lbServerPort: lbServerPort,
context: ctx,
}

if lbEnabled {
Expand Down Expand Up @@ -70,6 +73,7 @@ type proxy struct {
apiServerEnabled bool

apiServerURL string
apiServerPort string
supervisorURL string
supervisorPort string
initialSupervisorURL string
Expand All @@ -78,6 +82,7 @@ type proxy struct {

apiServerLB *loadbalancer.LoadBalancer
supervisorLB *loadbalancer.LoadBalancer
context context.Context
}

func (p *proxy) Update(addresses []string) {
Expand All @@ -96,6 +101,18 @@ func (p *proxy) Update(addresses []string) {
p.supervisorAddresses = supervisorAddresses
}

// SetHealthCheck registers a health-check callback for the given supervisor
// address on the supervisor and apiserver load balancers, where enabled.
// For the apiserver load balancer, the supervisor port in the address is
// rewritten to the apiserver port, since that LB tracks servers by that port.
func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) {
	if p.supervisorLB != nil {
		p.supervisorLB.SetHealthCheck(address, healthCheck)
	}

	if p.apiServerLB != nil {
		// Only rewrite the port when the address parses cleanly. The previous
		// code ignored the SplitHostPort error, which would register a mangled
		// ":port" address for any address without a valid host:port form.
		if host, _, err := net.SplitHostPort(address); err == nil {
			address = net.JoinHostPort(host, p.apiServerPort)
		}
		p.apiServerLB.SetHealthCheck(address, healthCheck)
	}
}

func (p *proxy) setSupervisorPort(addresses []string) []string {
var newAddresses []string
for _, address := range addresses {
Expand All @@ -114,12 +131,13 @@ func (p *proxy) setSupervisorPort(addresses []string) []string {
// load-balancing is enabled, another load-balancer is started on a port one below the supervisor
// load-balancer, and the address of this load-balancer is returned instead of the actual apiserver
// addresses.
func (p *proxy) SetAPIServerPort(ctx context.Context, port int, isIPv6 bool) error {
func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error {
u, err := url.Parse(p.initialSupervisorURL)
if err != nil {
return errors.Wrapf(err, "failed to parse server URL %s", p.initialSupervisorURL)
}
u.Host = sysnet.JoinHostPort(u.Hostname(), strconv.Itoa(port))
p.apiServerPort = strconv.Itoa(port)
u.Host = sysnet.JoinHostPort(u.Hostname(), p.apiServerPort)

p.apiServerURL = u.String()
p.apiServerEnabled = true
Expand All @@ -129,7 +147,7 @@ func (p *proxy) SetAPIServerPort(ctx context.Context, port int, isIPv6 bool) err
if lbServerPort != 0 {
lbServerPort = lbServerPort - 1
}
lb, err := loadbalancer.New(ctx, p.dataDir, loadbalancer.APIServerServiceName, p.apiServerURL, lbServerPort, isIPv6)
lb, err := loadbalancer.New(p.context, p.dataDir, loadbalancer.APIServerServiceName, p.apiServerURL, lbServerPort, isIPv6)
if err != nil {
return err
}
Expand Down
56 changes: 43 additions & 13 deletions pkg/agent/tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tunnel
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -289,7 +290,9 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
disconnect := map[string]context.CancelFunc{}
for _, address := range proxy.SupervisorAddresses() {
if _, ok := disconnect[address]; !ok {
disconnect[address] = a.connect(ctx, wg, address, tlsConfig)
conn := a.connect(ctx, wg, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
}
}

Expand Down Expand Up @@ -361,7 +364,9 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
for _, address := range proxy.SupervisorAddresses() {
validEndpoint[address] = true
if _, ok := disconnect[address]; !ok {
disconnect[address] = a.connect(ctx, nil, address, tlsConfig)
conn := a.connect(ctx, nil, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
}
}

Expand Down Expand Up @@ -403,32 +408,54 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo
return false
}

// agentConnection bundles the controls for a single remotedialer tunnel
// connection to a server: a cancel function to tear it down, and an accessor
// reporting whether the tunnel currently believes it is connected. The
// connected func is handed to proxy.SetHealthCheck as the server's health check.
type agentConnection struct {
	// cancel stops the connection's context, ending the connect retry loop.
	cancel context.CancelFunc
	// connected reports the current connection state of the tunnel.
	connected func() bool
}

// connect initiates a connection to the remotedialer server. Incoming dial requests from
// the server will be checked by the authorizer function prior to being fulfilled.
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) context.CancelFunc {
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
ws := &websocket.Dialer{
TLSClientConfig: tlsConfig,
}

// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
connected := true

once := sync.Once{}
if waitGroup != nil {
waitGroup.Add(1)
}

ctx, cancel := context.WithCancel(rootCtx)
auth := func(proto, address string) bool {
return a.authorized(rootCtx, proto, address)
}

onConnect := func(_ context.Context, _ *remotedialer.Session) error {
connected = true
logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return nil
}

// Start remotedialer connect loop in a goroutine to ensure a connection to the target server
go func() {
for {
remotedialer.ClientConnect(ctx, wsURL, nil, ws, func(proto, address string) bool {
return a.authorized(rootCtx, proto, address)
}, func(_ context.Context, _ *remotedialer.Session) error {
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return nil
})

// ConnectToProxy blocks until error or context cancellation
err := remotedialer.ConnectToProxy(ctx, wsURL, nil, auth, ws, onConnect)
connected = false
if err != nil && !errors.Is(err, context.Canceled) {
logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconecting...")
// wait between reconnection attempts to avoid hammering the server
time.Sleep(endpointDebounceDelay)
}
// If the context has been cancelled, exit the goroutine instead of retrying
if ctx.Err() != nil {
if waitGroup != nil {
once.Do(waitGroup.Done)
Expand All @@ -438,7 +465,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
}
}()

return cancel
return agentConnection{
cancel: cancel,
connected: func() bool { return connected },
}
}

// isKubeletPort returns true if the connection is to a reserved TCP port on a loopback address.
Expand Down
Loading