From 0a58416c23f774d9957d9787ad494c85da2cec95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hunyadv=C3=A1ri=20P=C3=A9ter?= Date: Tue, 26 Sep 2023 10:42:01 +0200 Subject: [PATCH 1/2] feature: ability to periodically check the service registration in the consul server --- v4/registry/consul/consul.go | 93 ++++++----- v4/registry/consul/options.go | 11 ++ v4/registry/consul/registry_test.go | 239 ++++++++++++++++++++++++++-- v4/registry/consul/watcher_test.go | 1 + 4 files changed, 290 insertions(+), 54 deletions(-) diff --git a/v4/registry/consul/consul.go b/v4/registry/consul/consul.go index cdada28f..cb3183bb 100644 --- a/v4/registry/consul/consul.go +++ b/v4/registry/consul/consul.go @@ -168,6 +168,7 @@ func (c *consulRegistry) Deregister(s *registry.Service, opts ...registry.Deregi c.Unlock() node := s.Nodes[0] + return c.Client().Agent().ServiceDeregister(node.Id) } @@ -176,12 +177,13 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register return errors.New("Require at least one node") } - var regTCPCheck bool - var regInterval time.Duration - var regHTTPCheck bool - var httpCheckConfig consul.AgentServiceCheck - - var options registry.RegisterOptions + var ( + regTCPCheck bool + regInterval time.Duration + regHTTPCheck bool + httpCheckConfig consul.AgentServiceCheck + options registry.RegisterOptions + ) for _, o := range opts { o(&options) } @@ -191,9 +193,9 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register regTCPCheck = true regInterval = tcpCheckInterval } - var ok bool - if httpCheckConfig, ok = c.opts.Context.Value("consul_http_check_config").(consul.AgentServiceCheck); ok { + if conf, ok := c.opts.Context.Value("consul_http_check_config").(consul.AgentServiceCheck); ok { regHTTPCheck = true + httpCheckConfig = conf } } @@ -212,37 +214,8 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register lastChecked := c.lastChecked[s.Name] c.Unlock() - // if it's already registered 
and matches then just pass the check - if ok && v == h { - if options.TTL == time.Duration(0) { - // ensure that our service hasn't been deregistered by Consul - if time.Since(lastChecked) <= getDeregisterTTL(regInterval) { - return nil - } - services, _, err := c.Client().Health().Checks(s.Name, c.queryOptions) - if err == nil { - for _, v := range services { - if v.ServiceID == node.Id { - return nil - } - } - } - } else { - // if the err is nil we're all good, bail out - // if not, we don't know what the state is, so full re-register - if err := c.Client().Agent().PassTTL("service:"+node.Id, ""); err == nil { - return nil - } - } - } - - // encode the tags - tags := encodeMetadata(node.Metadata) - tags = append(tags, encodeEndpoints(s.Endpoints)...) - tags = append(tags, encodeVersion(s.Version)...) - var check *consul.AgentServiceCheck - + checkTTL := regInterval if regTCPCheck { deregTTL := getDeregisterTTL(regInterval) @@ -255,6 +228,7 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register } else if regHTTPCheck { interval, _ := time.ParseDuration(httpCheckConfig.Interval) deregTTL := getDeregisterTTL(interval) + checkTTL = interval host, _, _ := net.SplitHostPort(node.Address) healthCheckURI := strings.Replace(httpCheckConfig.HTTP, "{host}", host, 1) @@ -269,12 +243,53 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register // if the TTL is greater than 0 create an associated check } else if options.TTL > time.Duration(0) { deregTTL := getDeregisterTTL(options.TTL) + checkTTL = options.TTL check = &consul.AgentServiceCheck{ TTL: fmt.Sprintf("%v", options.TTL), DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL), } } + if c.opts.Context != nil { + if ttl, ok := c.opts.Context.Value("consul_check_ttl").(time.Duration); ok { + checkTTL = ttl + } + } + + // if it's already registered and matches then just pass the check + if ok && v == h { + passing := false + if time.Since(lastChecked) > checkTTL 
{ + services, _, _ := c.Client().Health().Checks(s.Name, c.queryOptions) + for _, service := range services { + if service.ServiceID == node.Id && service.Status == "passing" { + passing = true + c.Lock() + c.lastChecked[s.Name] = time.Now() + c.Unlock() + break + } + } + } else { + passing = true + } + if passing { + if options.TTL == time.Duration(0) { + return nil + } + // if the err is nil we're all good, bail out + // if not, we don't know what the state is, so full re-register + if err := c.Client().Agent().UpdateTTL("service:"+node.Id, "", "pass"); err == nil { + return nil + } + } + c.Deregister(s) + } + + // encode the tags + tags := encodeMetadata(node.Metadata) + tags = append(tags, encodeEndpoints(s.Endpoints)...) + tags = append(tags, encodeVersion(s.Version)...) host, pt, _ := net.SplitHostPort(node.Address) if host == "" { @@ -316,7 +331,7 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register } // pass the healthcheck - return c.Client().Agent().PassTTL("service:"+node.Id, "") + return c.Client().Agent().UpdateTTL("service:"+node.Id, "", "pass") } func (c *consulRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { diff --git a/v4/registry/consul/options.go b/v4/registry/consul/options.go index 854313b9..82bae0af 100644 --- a/v4/registry/consul/options.go +++ b/v4/registry/consul/options.go @@ -28,6 +28,17 @@ func Config(c *consul.Config) registry.Option { } } +// CheckTTL allows you to periodically check the registration of the service to ensure +// that the registration actually exists in the consul +func CheckTTL(t time.Duration) registry.Option { + return func(o *registry.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, "consul_check_ttl", t) + } +} + // AllowStale sets whether any Consul server (non-leader) can service // a read. 
This allows for lower latency and higher throughput // at the cost of potentially stale data. diff --git a/v4/registry/consul/registry_test.go b/v4/registry/consul/registry_test.go index d18add4d..e9070e5d 100644 --- a/v4/registry/consul/registry_test.go +++ b/v4/registry/consul/registry_test.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/json" "errors" + "io" "net" "net/http" + "strings" "testing" "time" @@ -15,6 +17,7 @@ import ( type mockRegistry struct { body []byte + fn func(r *http.Request) ([]byte, int, error) status int err error url string @@ -29,20 +32,27 @@ func encodeData(obj interface{}) ([]byte, error) { return buf.Bytes(), nil } -func newMockServer(rg *mockRegistry, l net.Listener) error { +func newMockServer(l net.Listener, rgs ...*mockRegistry) error { mux := http.NewServeMux() - mux.HandleFunc(rg.url, func(w http.ResponseWriter, r *http.Request) { - if rg.err != nil { - http.Error(w, rg.err.Error(), 500) - return - } - w.WriteHeader(rg.status) - w.Write(rg.body) - }) + for _, rg := range rgs { + rgIn := rg + mux.HandleFunc(rg.url, func(w http.ResponseWriter, r *http.Request) { + body, status, err := rgIn.body, rgIn.status, rgIn.err + if rg.fn != nil { + body, status, err = rgIn.fn(r) + } + if err != nil { + http.Error(w, err.Error(), 500) + return + } + w.WriteHeader(status) + w.Write(body) + }) + } return http.Serve(l, mux) } -func newConsulTestRegistry(r *mockRegistry) (*consulRegistry, func()) { +func newConsulTestRegistry(chechkTTL time.Duration, r ...*mockRegistry) (*consulRegistry, func()) { l, err := net.Listen("tcp", "localhost:0") if err != nil { // blurgh?!! @@ -51,7 +61,7 @@ func newConsulTestRegistry(r *mockRegistry) (*consulRegistry, func()) { cfg := consul.DefaultConfig() cfg.Address = l.Addr().String() - go newMockServer(r, l) + go newMockServer(l, r...) 
var cr = &consulRegistry{ config: cfg, @@ -63,6 +73,7 @@ func newConsulTestRegistry(r *mockRegistry) (*consulRegistry, func()) { AllowStale: true, }, } + CheckTTL(time.Nanosecond)(&cr.opts) cr.Client() return cr, func() { @@ -76,7 +87,7 @@ func newServiceList(svc []*consul.ServiceEntry) []byte { } func TestConsul_GetService_WithError(t *testing.T) { - cr, cl := newConsulTestRegistry(&mockRegistry{ + cr, cl := newConsulTestRegistry(time.Second, &mockRegistry{ err: errors.New("client-error"), url: "/v1/health/service/service-name", }) @@ -106,7 +117,7 @@ func TestConsul_GetService_WithHealthyServiceNodes(t *testing.T) { ), } - cr, cl := newConsulTestRegistry(&mockRegistry{ + cr, cl := newConsulTestRegistry(time.Second, &mockRegistry{ status: 200, body: newServiceList(svcs), url: "/v1/health/service/service-name", @@ -146,7 +157,7 @@ func TestConsul_GetService_WithUnhealthyServiceNode(t *testing.T) { ), } - cr, cl := newConsulTestRegistry(&mockRegistry{ + cr, cl := newConsulTestRegistry(time.Second, &mockRegistry{ status: 200, body: newServiceList(svcs), url: "/v1/health/service/service-name", @@ -186,7 +197,7 @@ func TestConsul_GetService_WithUnhealthyServiceNodes(t *testing.T) { ), } - cr, cl := newConsulTestRegistry(&mockRegistry{ + cr, cl := newConsulTestRegistry(time.Second, &mockRegistry{ status: 200, body: newServiceList(svcs), url: "/v1/health/service/service-name", @@ -206,3 +217,201 @@ func TestConsul_GetService_WithUnhealthyServiceNodes(t *testing.T) { t.Fatalf("Expected len of nodes to be `%d`, got `%d`.", exp, act) } } + +func TestConsul_TestRegistrer(t *testing.T) { + registerCalled := 0 + cr, cl := newConsulTestRegistry( + time.Second, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := io.ReadAll(r.Body) + exp := `{"ID":"nodeId","Name":"service1","Tags":["v-789c010000ffff00000001"],` + + `"Address":"address","Check":{"TTL":"1s","DeregisterCriticalServiceAfter":"1m5s"},"Checks":null}` + body := strings.TrimSpace(string(b)) + if 
body != exp { + t.Fatalf("Expected request to be %s`, got `%s`.", exp, body) + } + registerCalled++ + return []byte(`{"success"":true}`), 200, nil + }, + url: "/v1/agent/service/register", + }, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := io.ReadAll(r.Body) + exp := `{"Status":"passing","Output":""}` + body := strings.TrimSpace(string(b)) + if body != exp { + t.Fatalf("Expected request to be %s`, got `%s`.", exp, body) + } + return nil, 200, nil + }, + url: "/v1/agent/check/update/service:nodeId", + }, + ) + defer cl() + + service := &registry.Service{ + Name: "service1", + Nodes: []*registry.Node{ + { + Address: "address", + Id: "nodeId", + }, + }, + } + rOpts := []registry.RegisterOption{registry.RegisterTTL(time.Second)} + err := cr.Register(service, rOpts...) + if err != nil { + t.Fatal("Unexpected error", err) + } + err = cr.Register(service, rOpts...) + if err != nil { + t.Fatal("Unexpected error", err) + } + if registerCalled >= 1 { + t.Fatalf("Expected run time to be %d`, got `%d`.", 1, registerCalled) + } +} + +func TestConsul_TestRegistrerWithCheck(t *testing.T) { + registerCalled := 0 + cr, cl := newConsulTestRegistry( + time.Nanosecond, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := io.ReadAll(r.Body) + exp := `{"ID":"nodeId","Name":"service1","Tags":["v-789c010000ffff00000001"],` + + `"Address":"address","Check":{"TTL":"1s","DeregisterCriticalServiceAfter":"1m5s"},"Checks":null}` + body := strings.TrimSpace(string(b)) + if body != exp { + t.Fatalf("Expected request to be %s`, got `%s`.", exp, body) + } + registerCalled++ + return []byte(`{"success"":true}`), 200, nil + }, + url: "/v1/agent/service/register", + }, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := io.ReadAll(r.Body) + exp := `{"Status":"passing","Output":""}` + body := strings.TrimSpace(string(b)) + if body != exp { + t.Fatalf("Expected request to be %s`, got `%s`.", exp, body) + } + return nil, 200, 
nil + }, + url: "/v1/agent/check/update/service:nodeId", + }, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := encodeData([]*consul.HealthCheck{ + newHealthCheck("nodeId", "service1", "passing"), + }) + return b, 200, nil + }, + url: "/v1/health/checks/service1", + }, + ) + defer cl() + + service := &registry.Service{ + Name: "service1", + Nodes: []*registry.Node{ + { + Address: "address", + Id: "nodeId", + }, + }, + } + rOpts := []registry.RegisterOption{registry.RegisterTTL(time.Second)} + err := cr.Register(service, rOpts...) + if err != nil { + t.Fatal("Unexpected error", err) + } + err = cr.Register(service, rOpts...) + if err != nil { + t.Fatal("Unexpected error", err) + } + + if registerCalled >= 1 { + t.Fatalf("Expected run time to be %d`, got `%d`.", 1, registerCalled) + } +} + +func TestConsul_TestRegistrerWithFailedCheck(t *testing.T) { + registerCalled := 0 + deregisterCalled := 0 + cr, cl := newConsulTestRegistry( + time.Nanosecond, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := io.ReadAll(r.Body) + exp := `{"ID":"nodeId","Name":"service1","Tags":["v-789c010000ffff00000001"],` + + `"Address":"address","Check":{"TTL":"1s","DeregisterCriticalServiceAfter":"1m5s"},"Checks":null}` + body := strings.TrimSpace(string(b)) + if body != exp { + t.Fatalf("Expected request to be %s`, got `%s`.", exp, body) + } + registerCalled++ + return []byte(`{"success"":true}`), 200, nil + }, + url: "/v1/agent/service/register", + }, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + deregisterCalled++ + return []byte(`{"success"":true}`), 200, nil + }, + url: "/v1/agent/service/deregister/nodeId", + }, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := io.ReadAll(r.Body) + exp := `{"Status":"passing","Output":""}` + body := strings.TrimSpace(string(b)) + if body != exp { + t.Fatalf("Expected request to be %s`, got `%s`.", exp, body) + } + return nil, 200, nil + }, + url: 
"/v1/agent/check/update/service:nodeId", + }, + &mockRegistry{ + fn: func(r *http.Request) ([]byte, int, error) { + b, _ := encodeData([]*consul.HealthCheck{ + newHealthCheck("nodeIdsdfsd", "service1", "passing"), + }) + return b, 200, nil + }, + url: "/v1/health/checks/service1", + }, + ) + defer cl() + + service := &registry.Service{ + Name: "service1", + Nodes: []*registry.Node{ + { + Address: "address", + Id: "nodeId", + }, + }, + } + rOpts := []registry.RegisterOption{registry.RegisterTTL(time.Second)} + err := cr.Register(service, rOpts...) + if err != nil { + t.Fatal("Unexpected error", err) + } + err = cr.Register(service, rOpts...) + if err != nil { + t.Fatal("Unexpected error", err) + } + if registerCalled >= 3 { + t.Fatalf("Expected register run time to be %d`, got `%d`.", 2, registerCalled) + } + if deregisterCalled < 1 { + t.Fatalf("Expected deregister run time to be %d`, got `%d`.", 1, deregisterCalled) + } +} diff --git a/v4/registry/consul/watcher_test.go b/v4/registry/consul/watcher_test.go index 95dfbed7..358c990b 100644 --- a/v4/registry/consul/watcher_test.go +++ b/v4/registry/consul/watcher_test.go @@ -66,6 +66,7 @@ func newWatcher() *consulWatcher { func newHealthCheck(node, name, status string) *api.HealthCheck { return &api.HealthCheck{ + ServiceID: node, Node: node, Name: name, Status: status, From 4eaf9f172c4ff453a5ec9cefdb7517db481eff83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hunyadv=C3=A1ri=20P=C3=A9ter?= Date: Tue, 26 Sep 2023 11:44:02 +0200 Subject: [PATCH 2/2] feature: v3 ability to periodically check the service registration in the consul server --- v3/registry/consul/consul.go | 93 ++++++++++++++++++++--------------- v3/registry/consul/options.go | 15 ++++-- 2 files changed, 65 insertions(+), 43 deletions(-) diff --git a/v3/registry/consul/consul.go b/v3/registry/consul/consul.go index ddf00117..7deab431 100644 --- a/v3/registry/consul/consul.go +++ b/v3/registry/consul/consul.go @@ -168,6 +168,7 @@ func (c *consulRegistry) 
Deregister(s *registry.Service, opts ...registry.Deregi c.Unlock() node := s.Nodes[0] + return c.Client().Agent().ServiceDeregister(node.Id) } @@ -176,12 +177,13 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register return errors.New("Require at least one node") } - var regTCPCheck bool - var regInterval time.Duration - var regHTTPCheck bool - var httpCheckConfig consul.AgentServiceCheck - - var options registry.RegisterOptions + var ( + regTCPCheck bool + regInterval time.Duration + regHTTPCheck bool + httpCheckConfig consul.AgentServiceCheck + options registry.RegisterOptions + ) for _, o := range opts { o(&options) } @@ -191,9 +193,9 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register regTCPCheck = true regInterval = tcpCheckInterval } - var ok bool - if httpCheckConfig, ok = c.opts.Context.Value("consul_http_check_config").(consul.AgentServiceCheck); ok { + if conf, ok := c.opts.Context.Value("consul_http_check_config").(consul.AgentServiceCheck); ok { regHTTPCheck = true + httpCheckConfig = conf } } @@ -212,37 +214,8 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register lastChecked := c.lastChecked[s.Name] c.Unlock() - // if it's already registered and matches then just pass the check - if ok && v == h { - if options.TTL == time.Duration(0) { - // ensure that our service hasn't been deregistered by Consul - if time.Since(lastChecked) <= getDeregisterTTL(regInterval) { - return nil - } - services, _, err := c.Client().Health().Checks(s.Name, c.queryOptions) - if err == nil { - for _, v := range services { - if v.ServiceID == node.Id { - return nil - } - } - } - } else { - // if the err is nil we're all good, bail out - // if not, we don't know what the state is, so full re-register - if err := c.Client().Agent().PassTTL("service:"+node.Id, ""); err == nil { - return nil - } - } - } - - // encode the tags - tags := encodeMetadata(node.Metadata) - tags = append(tags, 
encodeEndpoints(s.Endpoints)...) - tags = append(tags, encodeVersion(s.Version)...) - var check *consul.AgentServiceCheck - + checkTTL := regInterval if regTCPCheck { deregTTL := getDeregisterTTL(regInterval) @@ -255,6 +228,7 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register } else if regHTTPCheck { interval, _ := time.ParseDuration(httpCheckConfig.Interval) deregTTL := getDeregisterTTL(interval) + checkTTL = interval host, _, _ := net.SplitHostPort(node.Address) healthCheckURI := strings.Replace(httpCheckConfig.HTTP, "{host}", host, 1) @@ -269,12 +243,53 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register // if the TTL is greater than 0 create an associated check } else if options.TTL > time.Duration(0) { deregTTL := getDeregisterTTL(options.TTL) + checkTTL = options.TTL check = &consul.AgentServiceCheck{ TTL: fmt.Sprintf("%v", options.TTL), DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL), } } + if c.opts.Context != nil { + if ttl, ok := c.opts.Context.Value("consul_check_ttl").(time.Duration); ok { + checkTTL = ttl + } + } + + // if it's already registered and matches then just pass the check + if ok && v == h { + passing := false + if time.Since(lastChecked) > checkTTL { + services, _, _ := c.Client().Health().Checks(s.Name, c.queryOptions) + for _, service := range services { + if service.ServiceID == node.Id && service.Status == "passing" { + passing = true + c.Lock() + c.lastChecked[s.Name] = time.Now() + c.Unlock() + break + } + } + } else { + passing = true + } + if passing { + if options.TTL == time.Duration(0) { + return nil + } + // if the err is nil we're all good, bail out + // if not, we don't know what the state is, so full re-register + if err := c.Client().Agent().UpdateTTL("service:"+node.Id, "", "pass"); err == nil { + return nil + } + } + c.Deregister(s) + } + + // encode the tags + tags := encodeMetadata(node.Metadata) + tags = append(tags, 
encodeEndpoints(s.Endpoints)...) + tags = append(tags, encodeVersion(s.Version)...) host, pt, _ := net.SplitHostPort(node.Address) if host == "" { @@ -315,7 +330,7 @@ func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Register } // pass the healthcheck - return c.Client().Agent().PassTTL("service:"+node.Id, "") + return c.Client().Agent().UpdateTTL("service:"+node.Id, "", "pass") } func (c *consulRegistry) GetService(name string, opts ...registry.GetOption) ([]*registry.Service, error) { diff --git a/v3/registry/consul/options.go b/v3/registry/consul/options.go index e6091202..298f9d8f 100644 --- a/v3/registry/consul/options.go +++ b/v3/registry/consul/options.go @@ -28,6 +28,17 @@ func Config(c *consul.Config) registry.Option { } } +// CheckTTL allows you to periodically check the registration of the service to ensure +// that the registration actually exists in the consul +func CheckTTL(t time.Duration) registry.Option { + return func(o *registry.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, "consul_check_ttl", t) + } +} + // AllowStale sets whether any Consul server (non-leader) can service // a read. This allows for lower latency and higher throughput // at the cost of potentially stale data. @@ -35,7 +46,6 @@ func Config(c *consul.Config) registry.Option { // Defaults to true. // // [1] https://www.consul.io/docs/agent/options.html#allow_stale -// func AllowStale(v bool) registry.Option { return func(o *registry.Options) { if o.Context == nil { @@ -49,7 +59,6 @@ func AllowStale(v bool) registry.Option { // Consul. See `Consul API` for more information [1]. 
// // [1] https://godoc.org/github.com/hashicorp/consul/api#QueryOptions -// func QueryOptions(q *consul.QueryOptions) registry.Option { return func(o *registry.Options) { if q == nil { @@ -62,13 +71,11 @@ func QueryOptions(q *consul.QueryOptions) registry.Option { } } -// // TCPCheck will tell the service provider to check the service address // and port every `t` interval. It will enabled only if `t` is greater than 0. // See `TCP + Interval` for more information [1]. // // [1] https://www.consul.io/docs/agent/checks.html -// func TCPCheck(t time.Duration) registry.Option { return func(o *registry.Options) { if t <= time.Duration(0) {