Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle client init errors #30

Merged
merged 11 commits into from
Aug 14, 2023
5 changes: 0 additions & 5 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
linters:
disable-all: true
enable:
- deadcode
- errcheck
- staticcheck
- unused
- gosimple
- ineffassign
- stylecheck
- structcheck
- typecheck
- varcheck
- unconvert
- bodyclose
- dupl
- goconst
- gocyclo
- gofmt
- golint
- interfacer
- lll
- misspell
- nakedret
Expand Down
101 changes: 65 additions & 36 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package courier

import (
"context"
"errors"
"fmt"
"os"
"sync"
Expand All @@ -11,6 +12,9 @@ import (
mqtt "github.com/eclipse/paho.mqtt.golang"
)

// ErrClientNotInitialized is returned when the client is not initialized
var ErrClientNotInitialized = errors.New("courier: client not initialized")

var newClientFunc = defaultNewClientFunc()

// Client allows to communicate with an MQTT broker
Expand Down Expand Up @@ -56,54 +60,33 @@ func NewClient(opts ...ClientOption) (*Client, error) {
}

// IsConnected checks whether the client is connected to the broker
func (c *Client) IsConnected() (online bool) {
c.execute(func(cc mqtt.Client) {
online = cc != nil && cc.IsConnectionOpen()
func (c *Client) IsConnected() bool {
var online bool

err := c.execute(func(cc mqtt.Client) {
online = cc.IsConnectionOpen()
})

return
return err == nil && online
}

// Start will attempt to connect to the broker.
func (c *Client) Start() (err error) {
if len(c.options.brokerAddress) != 0 {
c.execute(func(cc mqtt.Client) {
t := cc.Connect()
if !t.WaitTimeout(c.options.connectTimeout) {
err = ErrConnectTimeout

return
}

err = t.Error()
})
func (c *Client) Start() error {
if err := c.runConnect(); err != nil {
return err
}

if c.options.resolver != nil {
// try first connect attempt on start, then start a watcher on channel
select {
case <-time.After(c.options.connectTimeout):
err = ErrConnectTimeout

return
case addrs := <-c.options.resolver.UpdateChan():
c.attemptConnection(addrs)
}

go c.watchAddressUpdates(c.options.resolver)
return c.runResolver()
}

return
return nil
}

// Stop will disconnect from the broker and finish up any pending work on internal
// communication workers. This can only block until the period configured with
// the ClientOption WithGracefulShutdownPeriod.
func (c *Client) Stop() {
c.execute(func(cc mqtt.Client) {
cc.Disconnect(uint(c.options.gracefulShutdownPeriod / time.Millisecond))
})
}
func (c *Client) Stop() { _ = c.stop() }

// Run will start running the Client. This makes Client compatible with github.com/gojekfarm/xrun package.
// https://pkg.go.dev/github.com/gojekfarm/xrun
Expand All @@ -117,16 +100,27 @@ func (c *Client) Run(ctx context.Context) error {
}

<-ctx.Done()
c.Stop()

return nil
return c.stop()
}

func (c *Client) stop() error {
return c.execute(func(cc mqtt.Client) {
cc.Disconnect(uint(c.options.gracefulShutdownPeriod / time.Millisecond))
})
}

func (c *Client) execute(f func(mqtt.Client)) {
func (c *Client) execute(f func(mqtt.Client)) error {
c.mu.RLock()
defer c.mu.RUnlock()

if c.mqttClient == nil {
return ErrClientNotInitialized
}

f(c.mqttClient)

return nil
}

func (c *Client) handleToken(ctx context.Context, t mqtt.Token, timeoutErr error) error {
Expand Down Expand Up @@ -158,6 +152,41 @@ func (c *Client) waitForToken(ctx context.Context, t mqtt.Token, timeoutErr erro
return nil
}

func (c *Client) runResolver() error {
// try first connect attempt on start, then start a watcher on channel
select {
case <-time.After(c.options.connectTimeout):
return ErrConnectTimeout
case addrs := <-c.options.resolver.UpdateChan():
c.attemptConnection(addrs)
}

go c.watchAddressUpdates(c.options.resolver)

return nil
}

func (c *Client) runConnect() (err error) {
if len(c.options.brokerAddress) == 0 {
return nil
}

if e := c.execute(func(cc mqtt.Client) {
t := cc.Connect()
if !t.WaitTimeout(c.options.connectTimeout) {
err = ErrConnectTimeout

return
}

err = t.Error()
}); e != nil {
err = e
}

return
}

func toClientOptions(c *Client, o *clientOptions) *mqtt.ClientOptions {
opts := mqtt.NewClientOptions()

Expand Down
6 changes: 4 additions & 2 deletions client_publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ func publishHandler(c *Client) Publisher {
}

o := composeOptions(opts)
c.execute(func(cc mqtt.Client) {
if e := c.execute(func(cc mqtt.Client) {
t := cc.Publish(topic, o.qos, o.retained, buf.Bytes())
err = c.handleToken(ctx, t, ErrPublishTimeout)
})
}); e != nil {
err = e
}

return
})
Expand Down
8 changes: 8 additions & 0 deletions client_publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ func (s *ClientPublishSuite) TestPublish() {
tk.AssertExpectations(s.T())
})
}

s.Run("PublishOnUninitializedClient", func() {
c := &Client{
options: &clientOptions{newEncoder: DefaultEncoderFunc},
}
c.publisher = publishHandler(c)
s.True(errors.Is(c.Publish(context.Background(), "topic", "data"), ErrClientNotInitialized))
})
}

func (s *ClientPublishSuite) TestPublishMiddleware() {
Expand Down
4 changes: 4 additions & 0 deletions client_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ func (c *Client) newClient(addrs []TCPAddress, attempt int) mqtt.Client {

if err := t.Error(); err != nil {
// TODO: add retry backoff or use ExponentialStartStrategy utility
if c.options.onConnectionLostHandler != nil {
c.options.onConnectionLostHandler(err)
}

return c.newClient(addrs, attempt+1)
}

Expand Down
22 changes: 17 additions & 5 deletions client_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ import (
"time"

mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestClient_newClient(t *testing.T) {
tests := []struct {
name string
addrs []TCPAddress
newClientFunc func(*mqtt.ClientOptions) mqtt.Client
name string
addrs []TCPAddress
newClientFunc func(*mqtt.ClientOptions) mqtt.Client
onConnLostAssert func(*testing.T, error)
}{
{
name: "success_attempt_1",
Expand Down Expand Up @@ -92,6 +94,9 @@ func TestClient_newClient(t *testing.T) {
Port: 8888,
},
},
onConnLostAssert: func(t *testing.T, err error) {
assert.EqualError(t, err, "some error")
},
newClientFunc: func(o *mqtt.ClientOptions) mqtt.Client {
if o.Servers[0].String() != "tcp://localhost:1883" {
panic(o.Servers)
Expand Down Expand Up @@ -123,15 +128,22 @@ func TestClient_newClient(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
options: defaultClientOptions(),
opts := defaultClientOptions()
if tt.onConnLostAssert != nil {
opts.onConnectionLostHandler = func(err error) {
tt.onConnLostAssert(t, err)
}
}

c := &Client{options: opts}
newClientFunc.Store(tt.newClientFunc)
got := c.newClient(tt.addrs, 0)

got.(*mockClient).AssertExpectations(t)
})
}

newClientFunc.Store(mqtt.NewClient)
}

func TestClient_watchAddressUpdates(t *testing.T) {
Expand Down
12 changes: 8 additions & 4 deletions client_subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,22 @@ func subscriberFuncs(c *Client) Subscriber {
return NewSubscriberFuncs(
func(ctx context.Context, topic string, callback MessageHandler, opts ...Option) (err error) {
o := composeOptions(opts)
c.execute(func(cc mqtt.Client) {
if e := c.execute(func(cc mqtt.Client) {
t := cc.Subscribe(topic, o.qos, callbackWrapper(c, callback))
err = c.handleToken(ctx, t, ErrSubscribeTimeout)
})
}); e != nil {
err = e
}

return
},
func(ctx context.Context, topicsWithQos map[string]QOSLevel, callback MessageHandler) (err error) {
c.execute(func(cc mqtt.Client) {
if e := c.execute(func(cc mqtt.Client) {
t := cc.SubscribeMultiple(routeFilters(topicsWithQos), callbackWrapper(c, callback))
err = c.handleToken(ctx, t, ErrSubscribeMultipleTimeout)
})
}); e != nil {
err = e
}

return
},
Expand Down
7 changes: 7 additions & 0 deletions client_subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ func (s *ClientSubscribeSuite) TestSubscribe() {
tk.AssertExpectations(s.T())
})
}

s.Run("SubscribeOnUninitializedClient", func() {
c := &Client{}
c.subscriber = subscriberFuncs(c)
s.True(errors.Is(c.Subscribe(context.Background(), "topic", callback), ErrClientNotInitialized))
s.True(errors.Is(c.SubscribeMultiple(context.Background(), map[string]QOSLevel{"topic": QOSOne}, callback), ErrClientNotInitialized))
})
}

func (s *ClientSubscribeSuite) TestSubscribeMultiple() {
Expand Down
8 changes: 8 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ func (s *ClientSuite) TestStart() {
wantErr: errConnect,
},
}

for _, t := range tests {
s.Run(t.name, func() {
if t.newClientFunc != nil {
Expand Down Expand Up @@ -155,6 +156,13 @@ func (s *ClientSuite) TestStart() {
})
}
newClientFunc.Store(mqtt.NewClient)

s.Run("WithUninitializedClient", func() {
c := &Client{
options: &clientOptions{brokerAddress: "localhost:1883"},
}
s.True(errors.Is(c.Run(context.Background()), ErrClientNotInitialized))
})
}

func TestNewClientWithResolverOption(t *testing.T) {
Expand Down
6 changes: 4 additions & 2 deletions client_unsubscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ func (c *Client) UseUnsubscriberMiddleware(mwf ...UnsubscriberMiddlewareFunc) {

func unsubscriberHandler(c *Client) Unsubscriber {
return UnsubscriberFunc(func(ctx context.Context, topics ...string) (err error) {
c.execute(func(cc mqtt.Client) {
if e := c.execute(func(cc mqtt.Client) {
t := cc.Unsubscribe(topics...)
err = c.handleToken(ctx, t, ErrUnsubscribeTimeout)
})
}); e != nil {
err = e
}

return
})
Expand Down
7 changes: 7 additions & 0 deletions client_unsubscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func (s *ClientUnsubscribeSuite) TestUnsubscribe() {
wantErr: true,
},
}

for _, t := range testcases {
s.Run(t.name, func() {
c, err := NewClient(defOpts...)
Expand Down Expand Up @@ -158,6 +159,12 @@ func (s *ClientUnsubscribeSuite) TestUnsubscribe() {
tk.AssertExpectations(s.T())
})
}

s.Run("UnsubscribeOnUninitializedClient", func() {
c := &Client{}
c.unsubscriber = unsubscriberHandler(c)
s.True(errors.Is(c.Unsubscribe(context.Background(), topics...), ErrClientNotInitialized))
})
}

func (s *ClientUnsubscribeSuite) TestUnsubscribeMiddleware() {
Expand Down
Loading