Skip to content

Commit

Permalink
allow custom dialer for multiple hosts to be randomized
Browse files Browse the repository at this point in the history
  • Loading branch information
harshavardhana committed Mar 9, 2020
1 parent 8485827 commit 6962865
Showing 1 changed file with 55 additions and 27 deletions.
82 changes: 55 additions & 27 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package main

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"math/rand"
"net"
"net/http"
"net/http/httputil"
Expand All @@ -37,6 +39,7 @@ import (
type Backend struct {
endpoint string
proxy *httputil.ReverseProxy
httpClient *http.Client
up bool
healthCheckPath string
healthCheckDuration int
Expand All @@ -57,33 +60,45 @@ func (b *Backend) ErrorHandler(w http.ResponseWriter, r *http.Request, err error
func (b *Backend) healthCheck() {
healthCheckURL := b.endpoint + b.healthCheckPath
for {
resp, err := http.Get(healthCheckURL)
req, err := http.NewRequest(http.MethodGet, healthCheckURL, nil)
if err != nil {
if b.logging {
fmt.Printf("%s %s fails\n", b.endpoint, err)
}
b.up = false
time.Sleep(time.Duration(b.healthCheckDuration) * time.Second)
continue
}

resp, err := b.httpClient.Do(req)
switch {
case err == nil && b.healthCheckPath == "":
resp.Body.Close()
fallthrough
case err == nil && resp.StatusCode == http.StatusOK:
resp.Body.Close()
if b.logging {
fmt.Printf("%s is up\n", b.endpoint)
}
b.up = true
default:
if b.logging {
fmt.Printf("%s is down : %s\n", b.endpoint, err.Error())
fmt.Printf("%s is down : %s\n", b.endpoint, err)
}
b.up = false
}
time.Sleep(time.Duration(b.healthCheckDuration) * time.Second)
}
}

type LoadBalancer struct {
type loadBalancer struct {
backends []*Backend
next int // next backend the request should go to.
sync.RWMutex
}

// Returns the next backend the request should go to.
func (lb *LoadBalancer) nextProxy() *httputil.ReverseProxy {
func (lb *loadBalancer) nextProxy() *httputil.ReverseProxy {
lb.Lock()
defer lb.Unlock()

Expand All @@ -109,7 +124,7 @@ func (lb *LoadBalancer) nextProxy() *httputil.ReverseProxy {
}

// ServeHTTP - LoadBalancer implements http.Handler
func (lb *LoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (lb *loadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
proxy := lb.nextProxy()
if proxy == nil {
w.WriteHeader(http.StatusBadGateway)
Expand All @@ -127,13 +142,39 @@ func mustGetSystemCertPool() *x509.CertPool {
return pool
}

var rng = rand.New(rand.NewSource(time.Now().UTC().UnixNano()))

type dialContext func(ctx context.Context, network, address string) (net.Conn, error)

func newCustomDialContext(dialTimeout, dialKeepAlive time.Duration) dialContext {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: dialTimeout,
KeepAlive: dialKeepAlive,
}

host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

addrs, err := net.LookupHost(host)
if err != nil {
addrs = []string{host}
}

for i := range addrs {
addrs[i] = net.JoinHostPort(addrs[i], port)
}

return dialer.DialContext(ctx, network, addrs[rng.Intn(len(addrs))])
}
}

func clientTransport(ctx *cli.Context, enableTLS bool) http.RoundTripper {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 5 * time.Second,
}).DialContext,
Proxy: http.ProxyFromEnvironment,
DialContext: newCustomDialContext(5*time.Second, 5*time.Second),
MaxIdleConnsPerHost: 256,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
Expand Down Expand Up @@ -201,21 +242,6 @@ func sidekickMain(ctx *cli.Context) {
}
} else {
endpoints = ctx.Args()
if len(endpoints) == 1 {
target, err := url.Parse(endpoints[0])
if err != nil {
console.Fatalln(fmt.Errorf("Unable to parse input arg %s: %s", endpoints[0], err))
}
// Single endpoint do lookup address to get all IPs
addrs, err := net.LookupHost(target.Hostname())
if err != nil {
console.Fatalln(fmt.Errorf("Unable to lookup host %s: %s", endpoints[0], err))
}
endpoints = make([]string, len(addrs))
for i, addr := range addrs {
endpoints[i] = target.Scheme + "://" + net.JoinHostPort(addr, target.Port())
}
}
}

var backends []*Backend
Expand All @@ -238,13 +264,15 @@ func sidekickMain(ctx *cli.Context) {
}
proxy := httputil.NewSingleHostReverseProxy(target)
proxy.Transport = clientTransport(ctx, target.Scheme == "https")
backend := &Backend{endpoint, proxy, false, healthCheckPath, healthCheckDuration, logging}
backend := &Backend{endpoint, proxy, &http.Client{
Transport: proxy.Transport,
}, false, healthCheckPath, healthCheckDuration, logging}
go backend.healthCheck()
proxy.ErrorHandler = backend.ErrorHandler
backends = append(backends, backend)
}
console.Infoln("Listening on", addr)
if err := http.ListenAndServe(addr, &LoadBalancer{
if err := http.ListenAndServe(addr, &loadBalancer{
backends: backends,
}); err != nil {
console.Fatalln(err)
Expand Down

0 comments on commit 6962865

Please sign in to comment.