From 6aa824c145efa72d6500f3365e90a560308b3999 Mon Sep 17 00:00:00 2001 From: Philipp Date: Tue, 12 Sep 2023 21:25:57 +0200 Subject: [PATCH] new: WithHttpConnection option --- clientoptions.go | 56 +++++++++++++++++++++++++++++++++++++++++++++- httpserver_test.go | 30 ++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/clientoptions.go b/clientoptions.go index f5465464..4aed5e53 100644 --- a/clientoptions.go +++ b/clientoptions.go @@ -1,8 +1,10 @@ package signalr import ( + "context" "errors" "fmt" + "github.com/cenkalti/backoff/v4" ) @@ -36,6 +38,58 @@ func WithConnector(connectionFactory func() (Connection, error)) func(Party) err } } +// HttpConnectionFactory is a connectionFactory for WithConnector which first tries to create a connection +// with WebSockets (if it is allowed by the HttpConnection options) and if this fails, falls back to a SSE based connection. +func HttpConnectionFactory(ctx context.Context, address string, options ...func(*httpConnection) error) (Connection, error) { + conn := &httpConnection{} + for i, option := range options { + if err := option(conn); err != nil { + return nil, err + } + if conn.transports != nil { + // Remove the WithTransports option + options = append(options[:i], options[i+1:]...) + break + } + } + // If no WithTransports was given, NewHTTPConnection fallbacks to both + if conn.transports == nil { + conn.transports = []TransportType{TransportWebSockets, TransportServerSentEvents} + } + + for _, transport := range conn.transports { + // If Websockets are allowed, we try to connect with these + if transport == TransportWebSockets { + wsOptions := append(options, WithTransports(TransportWebSockets)) + conn, err := NewHTTPConnection(ctx, address, wsOptions...) + // If this is ok, return the conn + if err == nil { + return conn, err + } + break + } + } + for _, transport := range conn.transports { + // If SSE is allowed, with fallback to try these + if transport == TransportServerSentEvents { + sseOptions := append(options, WithTransports(TransportServerSentEvents)) + return NewHTTPConnection(ctx, address, sseOptions...) + } + } + // None of the transports worked + return nil, fmt.Errorf("can not connect with supported transports: %v", conn.transports) +} + +// WithHttpConnection first tries to create a connection +// with WebSockets (if it is allowed by the HttpConnection options) and if this fails, falls back to a SSE based connection. +// This strategy is also used for auto reconnect if this option is used. +// WithHttpConnection is a shortcut for WithConnector(HttpConnectionFactory(...)) +func WithHttpConnection(ctx context.Context, address string, options ...func(*httpConnection) error) func(Party) error { + return WithConnector(func() (Connection, error) { + return HttpConnectionFactory(ctx, address, options...) + }) +} + // WithReceiver sets the object which will receive server side calls to client methods (e.g. callbacks) func WithReceiver(receiver interface{}) func(Party) error { return func(party Party) error { @@ -64,7 +118,7 @@ func WithBackoff(backoffFactory func() backoff.BackOff) func(party Party) error } // TransferFormat sets the transfer format used on the transport. Allowed values are "Text" and "Binary" -func TransferFormat(format string) func(Party) error { +func TransferFormat(format TransferFormatType) func(Party) error { return func(p Party) error { if c, ok := p.(*client); ok { switch format { diff --git a/httpserver_test.go b/httpserver_test.go index aa9ca6c7..47466245 100644 --- a/httpserver_test.go +++ b/httpserver_test.go @@ -48,7 +48,7 @@ var _ = Describe("HTTP server", func() { transport = TransportServerSentEvents transferFormat = TransferFormatText } - Context(fmt.Sprintf("%v %v", transport[0], transport[1]), func() { + Context(fmt.Sprintf("%v %v", transport, transferFormat), func() { Context("A correct negotiation request is sent", func() { It(fmt.Sprintf("should send a correct negotiation response with support for %v with text protocol", transport), func(done Done) { // Start server @@ -177,6 +177,34 @@ var _ = Describe("HTTP server", func() { }) }) +var _ = Describe("HTTP client", func() { + Context("WithHttpConnection", func() { + It("should fallback to SSE (this can only be tested when httpConnection is tampered with)", func(done Done) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + server, err := NewServer(ctx, SimpleHubFactory(&addHub{}), HTTPTransports(TransportWebSockets, TransportServerSentEvents), testLoggerOption()) + Expect(err).NotTo(HaveOccurred()) + router := http.NewServeMux() + server.MapHTTP(WithHTTPServeMux(router), "/hub") + testServer := httptest.NewServer(router) + url, _ := url.Parse(testServer.URL) + port, _ := strconv.Atoi(url.Port()) + waitForPort(port) + + client, err := NewClient(ctx, WithHttpConnection(ctx, fmt.Sprintf("http://127.0.0.1:%v/hub", port))) + Expect(err).NotTo(HaveOccurred()) + + client.Start() + Expect(<-client.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred()) + result := <-client.Invoke("Add2", 2) + Expect(result.Error).NotTo(HaveOccurred()) + + close(done) + }, 2.0) + }) +}) + type nonProtocolLogger struct { logger StructuredLogger }