Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support passing a slog.logger #16

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pkg/assured/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"strconv"
Expand Down Expand Up @@ -32,7 +31,7 @@ func NewClient(opts ...Option) *Client {
var err error
c.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", c.Options.Port))
if err != nil {
slog.With("error", err, "port", c.Options.Port).Error("unable to create http listener")
c.logger.With("error", err, "port", c.Options.Port).Error("unable to create http listener")
} else {
c.Options.Port = c.listener.Addr().(*net.TCPAddr).Port
}
Expand Down
22 changes: 12 additions & 10 deletions pkg/assured/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type AssuredEndpoints struct {
madeCalls *CallStore
callbackCalls *CallStore
trackMadeCalls bool
logger *slog.Logger
}

// NewAssuredEndpoints creates a new instance of assured endpoints
Expand All @@ -29,6 +30,7 @@ func NewAssuredEndpoints(options Options) *AssuredEndpoints {
callbackCalls: NewCallStore(),
httpClient: options.httpClient,
trackMadeCalls: options.trackMadeCalls,
logger: options.logger,
}
}

Expand All @@ -47,15 +49,15 @@ func (a *AssuredEndpoints) WrappedEndpoint(handler func(context.Context, *Call)
// GivenEndpoint is used to stub out a call for a given path
func (a *AssuredEndpoints) GivenEndpoint(ctx context.Context, call *Call) (interface{}, error) {
a.assuredCalls.Add(call)
slog.With("path", call.ID()).Info("assured call set")
a.logger.With("path", call.ID()).Info("assured call set")

return call, nil
}

// GivenCallbackEndpoint is used to stub out callbacks for a callback key
func (a *AssuredEndpoints) GivenCallbackEndpoint(ctx context.Context, call *Call) (interface{}, error) {
a.callbackCalls.AddAt(call.Headers[AssuredCallbackKey], call)
slog.With("key", call.Headers[AssuredCallbackKey], "target", call.Headers[AssuredCallbackTarget]).Info("assured callback set")
a.logger.With("key", call.Headers[AssuredCallbackKey], "target", call.Headers[AssuredCallbackTarget]).Info("assured callback set")

return call, nil
}
Expand All @@ -64,7 +66,7 @@ func (a *AssuredEndpoints) GivenCallbackEndpoint(ctx context.Context, call *Call
func (a *AssuredEndpoints) WhenEndpoint(ctx context.Context, call *Call) (interface{}, error) {
calls := a.assuredCalls.Get(call.ID())
if len(calls) == 0 {
slog.With("path", call.ID()).Info("assured call not found")
a.logger.With("path", call.ID()).Info("assured call not found")
return nil, errors.New("No assured calls")
}

Expand All @@ -84,7 +86,7 @@ func (a *AssuredEndpoints) WhenEndpoint(ctx context.Context, call *Call) (interf
time.Sleep(time.Duration(delay) * time.Second)
}

slog.With("path", call.ID()).Info("assured call responded")
a.logger.With("path", call.ID()).Info("assured call responded")
return assured, nil
}

Expand All @@ -100,10 +102,10 @@ func (a *AssuredEndpoints) VerifyEndpoint(ctx context.Context, call *Call) (inte
func (a *AssuredEndpoints) ClearEndpoint(ctx context.Context, call *Call) (interface{}, error) {
a.assuredCalls.Clear(call.ID())
a.madeCalls.Clear(call.ID())
slog.With("path", call.ID()).Info("cleared calls for path")
a.logger.With("path", call.ID()).Info("cleared calls for path")
if call.Headers[AssuredCallbackKey] != "" {
a.callbackCalls.Clear(call.Headers[AssuredCallbackKey])
slog.With("key", call.Headers[AssuredCallbackKey]).Info("cleared calls for key")
a.logger.With("key", call.Headers[AssuredCallbackKey]).Info("cleared calls for key")
}

return nil, nil
Expand All @@ -114,7 +116,7 @@ func (a *AssuredEndpoints) ClearAllEndpoint(ctx context.Context, i interface{})
a.assuredCalls.ClearAll()
a.madeCalls.ClearAll()
a.callbackCalls.ClearAll()
slog.Info("cleared all calls")
a.logger.Info("cleared all calls")

return nil, nil
}
Expand All @@ -127,7 +129,7 @@ func (a *AssuredEndpoints) sendCallback(target string, call *Call) {
}
req, err := http.NewRequest(call.Method, target, bytes.NewBuffer(call.Response))
if err != nil {
slog.With("target", target, "error", err).Info("failed to build callback request")
a.logger.With("target", target, "error", err).Info("failed to build callback request")
return
}
for key, value := range call.Headers {
Expand All @@ -137,8 +139,8 @@ func (a *AssuredEndpoints) sendCallback(target string, call *Call) {
time.Sleep(time.Duration(delay) * time.Second)
resp, err := a.httpClient.Do(req)
if err != nil {
slog.With("target", target, "error", err).Info("failed to reach callback target")
a.logger.With("target", target, "error", err).Info("failed to reach callback target")
return
}
slog.With("target", target, "status_code", resp.StatusCode).Info("sent callback to target")
a.logger.With("target", target, "status_code", resp.StatusCode).Info("sent callback to target")
}
9 changes: 9 additions & 0 deletions pkg/assured/endpoints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package assured

import (
"context"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -16,6 +17,7 @@ func TestNewAssuredEndpoints(t *testing.T) {
assuredCalls: NewCallStore(),
madeCalls: NewCallStore(),
trackMadeCalls: true,
logger: slog.Default(),
}
actual := NewAssuredEndpoints(DefaultOptions)

Expand Down Expand Up @@ -121,6 +123,7 @@ func TestWhenEndpointSuccess(t *testing.T) {
madeCalls: NewCallStore(),
callbackCalls: NewCallStore(),
trackMadeCalls: true,
logger: slog.Default(),
}
expected := map[string][]*Call{
"GET:test/assured": {testCall2(), testCall1()},
Expand Down Expand Up @@ -153,6 +156,7 @@ func TestWhenEndpointSuccessTrackingDisabled(t *testing.T) {
madeCalls: NewCallStore(),
callbackCalls: NewCallStore(),
trackMadeCalls: false,
logger: slog.Default(),
}
expected := map[string][]*Call{
"GET:test/assured": {testCall2(), testCall1()},
Expand Down Expand Up @@ -198,6 +202,7 @@ func TestWhenEndpointSuccessCallbacks(t *testing.T) {
data: map[string][]*Call{"call-key": {call}},
},
trackMadeCalls: true,
logger: slog.Default(),
}

c, err := endpoints.WhenEndpoint(context.TODO(), assured)
Expand Down Expand Up @@ -230,6 +235,7 @@ func TestWhenEndpointSuccessDelayed(t *testing.T) {
data: map[string][]*Call{"call-key": {call}},
},
trackMadeCalls: true,
logger: slog.Default(),
}
start := time.Now()
c, err := endpoints.WhenEndpoint(context.TODO(), assured)
Expand Down Expand Up @@ -310,6 +316,7 @@ func TestClearEndpointSuccess(t *testing.T) {
madeCalls: fullAssuredCalls,
callbackCalls: NewCallStore(),
trackMadeCalls: true,
logger: slog.Default(),
}
expected := map[string][]*Call{
"POST:teapot/assured": {testCall3()},
Expand Down Expand Up @@ -348,6 +355,7 @@ func TestClearEndpointSuccessCallback(t *testing.T) {
},
},
trackMadeCalls: true,
logger: slog.Default(),
}

c, err := endpoints.ClearEndpoint(context.TODO(), testCallback())
Expand All @@ -365,6 +373,7 @@ func TestClearAllEndpointSuccess(t *testing.T) {
madeCalls: fullAssuredCalls,
callbackCalls: fullAssuredCalls,
trackMadeCalls: true,
logger: slog.Default(),
}

c, err := endpoints.ClearAllEndpoint(context.TODO(), nil)
Expand Down
14 changes: 14 additions & 0 deletions pkg/assured/options.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package assured

import (
"log/slog"
"net/http"
)

var DefaultOptions = Options{
httpClient: http.DefaultClient,
host: "localhost",
trackMadeCalls: true,
logger: slog.Default(),
}

// Option is a function on that configures rest assured settings
Expand All @@ -32,6 +34,9 @@ type Options struct {

// trackMadeCalls toggles storing the requests made against the rest assured server. Defaults to true.
trackMadeCalls bool

// logger to use for logging. Defaults the default logger.
logger *slog.Logger
}

// WithHTTPClient sets the http client option.
Expand Down Expand Up @@ -74,6 +79,15 @@ func WithCallTracking(t bool) Option {
}
}

// WithCallTracking sets the trackMadeCalls option.
func WithLogger(l *slog.Logger) Option {
return func(o *Options) {
if l != nil {
o.logger = l
}
}
}

func (o *Options) applyOptions(opts ...Option) {
for _, opt := range opts {
opt(o)
Expand Down
9 changes: 9 additions & 0 deletions pkg/assured/options_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package assured

import (
"log/slog"
"net/http"
"os"
"reflect"
"testing"
)
Expand Down Expand Up @@ -40,6 +42,13 @@ func Test_applyOptions(t *testing.T) {
trackMadeCalls: true,
},
},
{
name: "with logger",
option: WithLogger(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{}))),
want: Options{
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{})),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down