diff --git a/roundrobin/rebalancer.go b/roundrobin/rebalancer.go index d667ec11..6dee9bf9 100644 --- a/roundrobin/rebalancer.go +++ b/roundrobin/rebalancer.go @@ -48,6 +48,9 @@ type Rebalancer struct { // creates new meters newMeter NewMeterFn + // sticky session object + stickySession *StickySession + requestRewriteListener RequestRewriteListener } @@ -80,6 +83,13 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption { } } +func RebalancerStickySession(stickySession *StickySession) RebalancerOption { + return func(r *Rebalancer) error { + r.stickySession = stickySession + return nil + } +} + // RebalancerErrorHandler is a functional argument that sets error handler of the server func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption { return func(r *Rebalancer) error { @@ -90,8 +100,9 @@ func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOpti func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) { rb := &Rebalancer{ - mtx: &sync.Mutex{}, - next: handler, + mtx: &sync.Mutex{}, + next: handler, + stickySession: nil, } for _, o := range opts { if err := o(rb); err != nil { @@ -139,20 +150,42 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { pw := &utils.ProxyWriter{W: w} start := rb.clock.UtcNow() - url, err := rb.next.NextServer() - if err != nil { - rb.errHandler.ServeHTTP(w, req, err) - return - } - - if log.GetLevel() >= log.DebugLevel { - //log which backend URL we're sending this request to - log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL") - } // make shallow copy of request before changing anything to avoid side effects newReq := *req - newReq.URL = url + stuck := false + + if rb.stickySession != nil { + cookieUrl, present, err := rb.stickySession.GetBackend(&newReq, rb.Servers()) + + if err != nil { + log.Infof("vulcand/oxy/roundrobin/rebalancer: error using server from cookie: %v", err) + } + + if present { + newReq.URL = cookieUrl + stuck = true + } + } + + if !stuck { + url, err := rb.next.NextServer() + if err != nil { + rb.errHandler.ServeHTTP(w, req, err) + return + } + + if log.GetLevel() >= log.DebugLevel { + //log which backend URL we're sending this request to + log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL") + } + + if rb.stickySession != nil { + rb.stickySession.StickBackend(url, &w) + } + + newReq.URL = url + } //Emit event to a listener if one exists if rb.requestRewriteListener != nil { @@ -161,7 +194,7 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { rb.next.Next().ServeHTTP(pw, &newReq) - rb.recordMetrics(url, pw.Code, rb.clock.UtcNow().Sub(start)) + rb.recordMetrics(newReq.URL, pw.Code, rb.clock.UtcNow().Sub(start)) rb.adjustWeights() } @@ -244,11 +277,11 @@ func (rb *Rebalancer) upsertServer(u *url.URL, weight int) error { return nil } -func (r *Rebalancer) findServer(u *url.URL) (*rbServer, int) { - if len(r.servers) == 0 { +func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) { + if len(rb.servers) == 0 { return nil, -1 } - for i, s := range r.servers { + for i, s := range rb.servers { if sameURL(u, s.url) { return s, i } @@ -351,7 +384,7 @@ func (rb *Rebalancer) markServers() bool { } func (rb *Rebalancer) convergeWeights() bool { - // If we have previoulsy changed servers try to restore weights to the original state + // If we have previously changed servers try to restore weights to the original state changed := false for _, s := range rb.servers { if s.origWeight == s.curWeight { diff --git a/roundrobin/rebalancer_test.go b/roundrobin/rebalancer_test.go index c654819e..46bde460 100644 --- a/roundrobin/rebalancer_test.go +++ b/roundrobin/rebalancer_test.go @@ -1,6 +1,7 @@ package roundrobin import ( + "io/ioutil" "net/http" "net/http/httptest" "time" @@ -8,7 +9,6 @@ import ( "github.com/mailgun/timetools" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/testutils" - . "gopkg.in/check.v1" ) @@ -339,6 +339,52 @@ func (s *RBSuite) TestRequestRewriteListener(c *C) { c.Assert(rb.requestRewriteListener, NotNil) } +func (s *RBSuite) TestRebalancerStickySession(c *C) { + a, b, x := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("x") + defer a.Close() + defer b.Close() + defer x.Close() + + sticky := NewStickySession("test") + c.Assert(sticky, NotNil) + + fwd, err := forward.New() + c.Assert(err, IsNil) + + lb, err := New(fwd) + c.Assert(err, IsNil) + + rb, err := NewRebalancer(lb, RebalancerStickySession(sticky)) + c.Assert(err, IsNil) + + rb.UpsertServer(testutils.ParseURI(a.URL)) + rb.UpsertServer(testutils.ParseURI(b.URL)) + rb.UpsertServer(testutils.ParseURI(x.URL)) + + proxy := httptest.NewServer(rb) + defer proxy.Close() + + for i := 0; i < 10; i++ { + req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(err, IsNil) + req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) + + resp, err := http.DefaultClient.Do(req) + c.Assert(err, IsNil) + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + + c.Assert(err, IsNil) + c.Assert(string(body), Equals, "a") + } + + c.Assert(rb.RemoveServer(testutils.ParseURI(a.URL)), IsNil) + c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"b", "x", "b"}) + c.Assert(rb.RemoveServer(testutils.ParseURI(b.URL)), IsNil) + c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"x", "x", "x"}) +} + type testMeter struct { rating float64 notReady bool diff --git a/roundrobin/rr.go b/roundrobin/rr.go index 951ae69b..d52b1471 100644 --- a/roundrobin/rr.go +++ b/roundrobin/rr.go @@ -30,6 +30,13 @@ func ErrorHandler(h utils.ErrorHandler) LBOption { } } +func EnableStickySession(stickySession *StickySession) LBOption { + return func(s *RoundRobin) error { + s.stickySession = stickySession + return nil + } +} + // ErrorHandler is a functional argument that sets error handler of the server func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { return func(s *RoundRobin) error { @@ -46,15 +53,17 @@ type RoundRobin struct { index int servers []*server currentWeight int + stickySession *StickySession requestRewriteListener RequestRewriteListener } func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { rr := &RoundRobin{ - next: next, - index: -1, - mutex: &sync.Mutex{}, - servers: []*server{}, + next: next, + index: -1, + mutex: &sync.Mutex{}, + servers: []*server{}, + stickySession: nil, } for _, o := range opts { if err := o(rr); err != nil { @@ -78,21 +87,40 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { defer logEntry.Debug("vulcand/oxy/roundrobin/rr: competed ServeHttp on request") } - url, err := r.NextServer() - if err != nil { - r.errHandler.ServeHTTP(w, req, err) - return + // make shallow copy of request before chaning anything to avoid side effects + newReq := *req + stuck := false + if r.stickySession != nil { + cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers()) + + if err != nil { + log.Infof("vulcand/oxy/roundrobin/rr: error using server from cookie: %v", err) + } + + if present { + newReq.URL = cookieURL + stuck = true + } + } + + if !stuck { + url, err := r.NextServer() + if err != nil { + r.errHandler.ServeHTTP(w, req, err) + return + } + + if r.stickySession != nil { + r.stickySession.StickBackend(url, &w) + } + newReq.URL = url } if log.GetLevel() >= log.DebugLevel { //log which backend URL we're sending this request to - log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL") + log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL") } - // make shallow copy of request before chaning anything to avoid side effects - newReq := *req - newReq.URL = url - //Emit event to a listener if one exists if r.requestRewriteListener != nil { r.requestRewriteListener(req, &newReq) diff --git a/roundrobin/stickysessions.go b/roundrobin/stickysessions.go new file mode 100644 index 00000000..3fabeb97 --- /dev/null +++ b/roundrobin/stickysessions.go @@ -0,0 +1,56 @@ +// package stickysession is a mixin for load balancers that implements layer 7 (http cookie) session affinity +package roundrobin + +import ( + "net/http" + "net/url" +) + +type StickySession struct { + cookieName string +} + +func NewStickySession(cookieName string) *StickySession { + return &StickySession{cookieName} +} + +// GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. +func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) { + cookie, err := req.Cookie(s.cookieName) + switch err { + case nil: + case http.ErrNoCookie: + return nil, false, nil + default: + return nil, false, err + } + + serverURL, err := url.Parse(cookie.Value) + if err != nil { + return nil, false, err + } + + if s.isBackendAlive(serverURL, servers) { + return serverURL, true, nil + } else { + return nil, false, nil + } +} + +func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) { + cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"} + http.SetCookie(*w, cookie) +} + +func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) bool { + if len(haystack) == 0 { + return false + } + + for _, serverURL := range haystack { + if sameURL(needle, serverURL) { + return true + } + } + return false +} diff --git a/roundrobin/stickysessions_test.go b/roundrobin/stickysessions_test.go new file mode 100644 index 00000000..08f49391 --- /dev/null +++ b/roundrobin/stickysessions_test.go @@ -0,0 +1,262 @@ +package roundrobin + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/vulcand/oxy/forward" + "github.com/vulcand/oxy/testutils" + + . "gopkg.in/check.v1" +) + +func TestStickySession(t *testing.T) { TestingT(t) } + +type StickySessionSuite struct{} + +var _ = Suite(&StickySessionSuite{}) + +func (s *StickySessionSuite) TestBasic(c *C) { + a := testutils.NewResponder("a") + b := testutils.NewResponder("b") + + defer a.Close() + defer b.Close() + + fwd, err := forward.New() + c.Assert(err, IsNil) + + sticky := NewStickySession("test") + c.Assert(sticky, NotNil) + + lb, err := New(fwd, EnableStickySession(sticky)) + c.Assert(err, IsNil) + + err = lb.UpsertServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + err = lb.UpsertServer(testutils.ParseURI(b.URL)) + c.Assert(err, IsNil) + + proxy := httptest.NewServer(lb) + defer proxy.Close() + + client := http.DefaultClient + + for i := 0; i < 10; i++ { + req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(err, IsNil) + req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) + + resp, err := client.Do(req) + c.Assert(err, IsNil) + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + + c.Assert(err, IsNil) + c.Assert(string(body), Equals, "a") + } +} + +func (s *StickySessionSuite) TestStickCookie(c *C) { + a := testutils.NewResponder("a") + b := testutils.NewResponder("b") + + defer a.Close() + defer b.Close() + + fwd, err := forward.New() + c.Assert(err, IsNil) + + sticky := NewStickySession("test") + c.Assert(sticky, NotNil) + + lb, err := New(fwd, EnableStickySession(sticky)) + c.Assert(err, IsNil) + + err = lb.UpsertServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + err = lb.UpsertServer(testutils.ParseURI(b.URL)) + c.Assert(err, IsNil) + + proxy := httptest.NewServer(lb) + defer proxy.Close() + + resp, err := http.Get(proxy.URL) + c.Assert(err, IsNil) + + cookie := resp.Cookies()[0] + c.Assert(cookie.Name, Equals, "test") + c.Assert(cookie.Value, Equals, a.URL) +} + +func (s *StickySessionSuite) TestRemoveRespondingServer(c *C) { + a := testutils.NewResponder("a") + b := testutils.NewResponder("b") + + defer a.Close() + defer b.Close() + + fwd, err := forward.New() + c.Assert(err, IsNil) + + sticky := NewStickySession("test") + c.Assert(sticky, NotNil) + + lb, err := New(fwd, EnableStickySession(sticky)) + c.Assert(err, IsNil) + + err = lb.UpsertServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + err = lb.UpsertServer(testutils.ParseURI(b.URL)) + c.Assert(err, IsNil) + + proxy := httptest.NewServer(lb) + defer proxy.Close() + + client := http.DefaultClient + + for i := 0; i < 10; i++ { + req, errReq := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(errReq, IsNil) + req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) + + resp, errReq := client.Do(req) + c.Assert(errReq, IsNil) + + defer resp.Body.Close() + body, errReq := ioutil.ReadAll(resp.Body) + + c.Assert(errReq, IsNil) + c.Assert(string(body), Equals, "a") + } + + err = lb.RemoveServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + + // Now, use the organic cookie response in our next requests. + req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(err, IsNil) + req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) + resp, err := client.Do(req) + c.Assert(err, IsNil) + + c.Assert(resp.Cookies()[0].Name, Equals, "test") + c.Assert(resp.Cookies()[0].Value, Equals, b.URL) + + for i := 0; i < 10; i++ { + req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(err, IsNil) + + resp, err := client.Do(req) + c.Assert(err, IsNil) + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + + c.Assert(err, IsNil) + c.Assert(string(body), Equals, "b") + } +} + +func (s *StickySessionSuite) TestRemoveAllServers(c *C) { + a := testutils.NewResponder("a") + b := testutils.NewResponder("b") + + defer a.Close() + defer b.Close() + + fwd, err := forward.New() + c.Assert(err, IsNil) + + sticky := NewStickySession("test") + c.Assert(sticky, NotNil) + + lb, err := New(fwd, EnableStickySession(sticky)) + c.Assert(err, IsNil) + + err = lb.UpsertServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + err = lb.UpsertServer(testutils.ParseURI(b.URL)) + c.Assert(err, IsNil) + + proxy := httptest.NewServer(lb) + defer proxy.Close() + + client := http.DefaultClient + + for i := 0; i < 10; i++ { + req, errReq := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(errReq, IsNil) + req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) + + resp, errReq := client.Do(req) + c.Assert(errReq, IsNil) + + defer resp.Body.Close() + body, errReq := ioutil.ReadAll(resp.Body) + + c.Assert(errReq, IsNil) + c.Assert(string(body), Equals, "a") + } + + err = lb.RemoveServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + err = lb.RemoveServer(testutils.ParseURI(b.URL)) + c.Assert(err, IsNil) + + // Now, use the organic cookie response in our next requests. + req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(err, IsNil) + req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) + resp, err := client.Do(req) + c.Assert(err, IsNil) + c.Assert(resp.StatusCode, Equals, http.StatusInternalServerError) +} + +func (s *StickySessionSuite) TestBadCookieVal(c *C) { + a := testutils.NewResponder("a") + + defer a.Close() + + fwd, err := forward.New() + c.Assert(err, IsNil) + + sticky := NewStickySession("test") + c.Assert(sticky, NotNil) + + lb, err := New(fwd, EnableStickySession(sticky)) + c.Assert(err, IsNil) + + err = lb.UpsertServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + + proxy := httptest.NewServer(lb) + defer proxy.Close() + + client := http.DefaultClient + + req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) + c.Assert(err, IsNil) + req.AddCookie(&http.Cookie{Name: "test", Value: "This is a patently invalid url! You can't parse it! :-)"}) + + resp, err := client.Do(req) + c.Assert(err, IsNil) + + body, err := ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) + c.Assert(string(body), Equals, "a") + + // Now, cycle off the good server to cause an error + err = lb.RemoveServer(testutils.ParseURI(a.URL)) + c.Assert(err, IsNil) + + resp, err = client.Do(req) + c.Assert(err, IsNil) + + _, err = ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) + c.Assert(resp.StatusCode, Equals, http.StatusInternalServerError) +}