Skip to content

Commit

Permalink
Add sticky session support (#97)
Browse files Browse the repository at this point in the history
* qnd cookie hacking

* Add work in progress sticky session support

Sticky sessions are set through an HTTP cookie.
If the cookie:
    * is not present, use the next server & set that as sticky
    * is present,
        * but is no longer valid, use the next server & set that as
          sticky
        * and valid, use that server without advancing .next.

* fix misleading comment

layer 7, that is... layer 8 is something different (https://en.wikipedia.org/wiki/Layer_8)

* Add sticky session support to rebalancing rr

* Fix incorrect test to match the actual sticky session actions

* Setting the cookie path to '/'

* Format and imports

* Fix log

* Fix error when invalid cookie

* fix variables name

* refactor: code review.
  • Loading branch information
juliens authored and emilevauge committed Dec 8, 2017
1 parent 5dfac99 commit 856ab51
Show file tree
Hide file tree
Showing 5 changed files with 457 additions and 32 deletions.
69 changes: 51 additions & 18 deletions roundrobin/rebalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type Rebalancer struct {
// creates new meters
newMeter NewMeterFn

// sticky session object
stickySession *StickySession

requestRewriteListener RequestRewriteListener
}

Expand Down Expand Up @@ -80,6 +83,13 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption {
}
}

func RebalancerStickySession(stickySession *StickySession) RebalancerOption {
return func(r *Rebalancer) error {
r.stickySession = stickySession
return nil
}
}

// RebalancerErrorHandler is a functional argument that sets error handler of the server
func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption {
return func(r *Rebalancer) error {
Expand All @@ -90,8 +100,9 @@ func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOpti

func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) {
rb := &Rebalancer{
mtx: &sync.Mutex{},
next: handler,
mtx: &sync.Mutex{},
next: handler,
stickySession: nil,
}
for _, o := range opts {
if err := o(rb); err != nil {
Expand Down Expand Up @@ -139,20 +150,42 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {

pw := &utils.ProxyWriter{W: w}
start := rb.clock.UtcNow()
url, err := rb.next.NextServer()
if err != nil {
rb.errHandler.ServeHTTP(w, req, err)
return
}

if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
}

// make shallow copy of request before changing anything to avoid side effects
newReq := *req
newReq.URL = url
stuck := false

if rb.stickySession != nil {
cookieUrl, present, err := rb.stickySession.GetBackend(&newReq, rb.Servers())

if err != nil {
log.Infof("vulcand/oxy/roundrobin/rebalancer: error using server from cookie: %v", err)
}

if present {
newReq.URL = cookieUrl
stuck = true
}
}

if !stuck {
url, err := rb.next.NextServer()
if err != nil {
rb.errHandler.ServeHTTP(w, req, err)
return
}

if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
}

if rb.stickySession != nil {
rb.stickySession.StickBackend(url, &w)
}

newReq.URL = url
}

//Emit event to a listener if one exists
if rb.requestRewriteListener != nil {
Expand All @@ -161,7 +194,7 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {

rb.next.Next().ServeHTTP(pw, &newReq)

rb.recordMetrics(url, pw.Code, rb.clock.UtcNow().Sub(start))
rb.recordMetrics(newReq.URL, pw.Code, rb.clock.UtcNow().Sub(start))
rb.adjustWeights()
}

Expand Down Expand Up @@ -244,11 +277,11 @@ func (rb *Rebalancer) upsertServer(u *url.URL, weight int) error {
return nil
}

func (r *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
if len(r.servers) == 0 {
func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
if len(rb.servers) == 0 {
return nil, -1
}
for i, s := range r.servers {
for i, s := range rb.servers {
if sameURL(u, s.url) {
return s, i
}
Expand Down Expand Up @@ -351,7 +384,7 @@ func (rb *Rebalancer) markServers() bool {
}

func (rb *Rebalancer) convergeWeights() bool {
// If we have previoulsy changed servers try to restore weights to the original state
// If we have previously changed servers try to restore weights to the original state
changed := false
for _, s := range rb.servers {
if s.origWeight == s.curWeight {
Expand Down
48 changes: 47 additions & 1 deletion roundrobin/rebalancer_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package roundrobin

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"time"

"github.com/mailgun/timetools"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/testutils"

. "gopkg.in/check.v1"
)

Expand Down Expand Up @@ -339,6 +339,52 @@ func (s *RBSuite) TestRequestRewriteListener(c *C) {
c.Assert(rb.requestRewriteListener, NotNil)
}

func (s *RBSuite) TestRebalancerStickySession(c *C) {
a, b, x := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("x")
defer a.Close()
defer b.Close()
defer x.Close()

sticky := NewStickySession("test")
c.Assert(sticky, NotNil)

fwd, err := forward.New()
c.Assert(err, IsNil)

lb, err := New(fwd)
c.Assert(err, IsNil)

rb, err := NewRebalancer(lb, RebalancerStickySession(sticky))
c.Assert(err, IsNil)

rb.UpsertServer(testutils.ParseURI(a.URL))
rb.UpsertServer(testutils.ParseURI(b.URL))
rb.UpsertServer(testutils.ParseURI(x.URL))

proxy := httptest.NewServer(rb)
defer proxy.Close()

for i := 0; i < 10; i++ {
req, err := http.NewRequest(http.MethodGet, proxy.URL, nil)
c.Assert(err, IsNil)
req.AddCookie(&http.Cookie{Name: "test", Value: a.URL})

resp, err := http.DefaultClient.Do(req)
c.Assert(err, IsNil)

defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)

c.Assert(err, IsNil)
c.Assert(string(body), Equals, "a")
}

c.Assert(rb.RemoveServer(testutils.ParseURI(a.URL)), IsNil)
c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"b", "x", "b"})
c.Assert(rb.RemoveServer(testutils.ParseURI(b.URL)), IsNil)
c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"x", "x", "x"})
}

type testMeter struct {
rating float64
notReady bool
Expand Down
54 changes: 41 additions & 13 deletions roundrobin/rr.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ func ErrorHandler(h utils.ErrorHandler) LBOption {
}
}

func EnableStickySession(stickySession *StickySession) LBOption {
return func(s *RoundRobin) error {
s.stickySession = stickySession
return nil
}
}

// ErrorHandler is a functional argument that sets error handler of the server
func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption {
return func(s *RoundRobin) error {
Expand All @@ -46,15 +53,17 @@ type RoundRobin struct {
index int
servers []*server
currentWeight int
stickySession *StickySession
requestRewriteListener RequestRewriteListener
}

func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
rr := &RoundRobin{
next: next,
index: -1,
mutex: &sync.Mutex{},
servers: []*server{},
next: next,
index: -1,
mutex: &sync.Mutex{},
servers: []*server{},
stickySession: nil,
}
for _, o := range opts {
if err := o(rr); err != nil {
Expand All @@ -78,21 +87,40 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
defer logEntry.Debug("vulcand/oxy/roundrobin/rr: competed ServeHttp on request")
}

url, err := r.NextServer()
if err != nil {
r.errHandler.ServeHTTP(w, req, err)
return
// make shallow copy of request before chaning anything to avoid side effects
newReq := *req
stuck := false
if r.stickySession != nil {
cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers())

if err != nil {
log.Infof("vulcand/oxy/roundrobin/rr: error using server from cookie: %v", err)
}

if present {
newReq.URL = cookieURL
stuck = true
}
}

if !stuck {
url, err := r.NextServer()
if err != nil {
r.errHandler.ServeHTTP(w, req, err)
return
}

if r.stickySession != nil {
r.stickySession.StickBackend(url, &w)
}
newReq.URL = url
}

if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
}

// make shallow copy of request before chaning anything to avoid side effects
newReq := *req
newReq.URL = url

//Emit event to a listener if one exists
if r.requestRewriteListener != nil {
r.requestRewriteListener(req, &newReq)
Expand Down
56 changes: 56 additions & 0 deletions roundrobin/stickysessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// package stickysession is a mixin for load balancers that implements layer 7 (http cookie) session affinity
package roundrobin

import (
"net/http"
"net/url"
)

type StickySession struct {
cookieName string
}

func NewStickySession(cookieName string) *StickySession {
return &StickySession{cookieName}
}

// GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers.
func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) {
cookie, err := req.Cookie(s.cookieName)
switch err {
case nil:
case http.ErrNoCookie:
return nil, false, nil
default:
return nil, false, err
}

serverURL, err := url.Parse(cookie.Value)
if err != nil {
return nil, false, err
}

if s.isBackendAlive(serverURL, servers) {
return serverURL, true, nil
} else {
return nil, false, nil
}
}

func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) {
cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"}
http.SetCookie(*w, cookie)
}

func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) bool {
if len(haystack) == 0 {
return false
}

for _, serverURL := range haystack {
if sameURL(needle, serverURL) {
return true
}
}
return false
}
Loading

0 comments on commit 856ab51

Please sign in to comment.