diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 681e4e0..131dbed 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,7 +18,7 @@ jobs: uses: golangci/golangci-lint-action@v3 with: # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. - version: v1.57 + version: v1.61 test: name: Test and Coverage diff --git a/client_test.go b/client_test.go index 1f40466..6c75ec0 100644 --- a/client_test.go +++ b/client_test.go @@ -85,6 +85,10 @@ func (s *simpleHub) Callback(arg1 string) { s.Hub.Clients().Caller().Send("OnCallback", strings.ToUpper(arg1)) } +func (s *simpleHub) SwearBack(arg1 string) { + s.Hub.Clients().Caller().Send("#@!", strings.ToUpper(arg1)) +} + func (s *simpleHub) ReadStream(i int) chan string { ch := make(chan string) go func() { @@ -279,6 +283,31 @@ var _ = Describe("Client", func() { cancelClient() close(done) }, 5.0) + It("should invoke a server method and get the result via callback with alternate name", func(done Done) { + receiver := &simpleReceiver{} + _, client, _, cancelClient := getTestBed(receiver, formatOption) + receiver.result.Store("x") + errCh := client.Send("SwearBack", "low") + ch := make(chan string, 1) + go func() { + for { + if result, ok := receiver.result.Load().(string); ok { + if result != "x" { + ch <- result + break + } + } + } + }() + select { + case val := <-ch: + Expect(val).To(Equal("LOW")) + case err := <-errCh: + Expect(err).NotTo(HaveOccurred()) + } + cancelClient() + close(done) + }, 5.0) It("should invoke a server method and return the error when arguments don't match", func(done Done) { receiver := &simpleReceiver{} _, client, _, cancelClient := getTestBed(receiver, formatOption) @@ -480,7 +509,8 @@ func getTestBed(receiver interface{}, formatOption func(Party) error) (Server, C // Create the Client var ctx context.Context ctx, cancelClient := context.WithCancel(context.Background()) - client, _ := NewClient(ctx, WithConnection(cliConn), WithReceiver(receiver), testLoggerOption(), formatOption) + client, _ := NewClient(ctx, WithConnection(cliConn), WithReceiver(receiver), + WithAlternateMethodName("OnCallback", "#@!"), testLoggerOption(), formatOption) // Start it client.Start() return server, client, cliConn, cancelClient diff --git a/invocation_test.go b/invocation_test.go index dcecaf1..4ba24f3 100644 --- a/invocation_test.go +++ b/invocation_test.go @@ -81,6 +81,18 @@ var _ = Describe("Invocation", func() { close(done) }, 2.0) }) + Context("When invoked by the client with alternate method name", func() { + It("should be invoked and return a completion", func(done Done) { + conn.ClientSend(`{"type":1,"invocationId": "123","target":"."}`) + Expect(<-invocationQueue).To(Equal("Simple()")) + recv := (<-conn.received).(completionMessage) + Expect(recv).NotTo(BeNil()) + Expect(recv.InvocationID).To(Equal("123")) + Expect(recv.Result).To(BeNil()) + Expect(recv.Error).To(Equal("")) + close(done) + }, 2.0) + }) Context("When invoked by the client two times in one frame", func() { It("should be invoked and return a completion", func(done Done) { conn.ClientSend(`{"type":1,"invocationId": "123","target":"simple"}`) diff --git a/loop.go b/loop.go index 02b16b5..5716b0a 100644 --- a/loop.go +++ b/loop.go @@ -192,7 +192,8 @@ func (l *loop) GetNewID() string { func (l *loop) handleInvocationMessage(invocation invocationMessage) { _ = l.dbg.Log(evt, msgRecv, msg, fmtMsg(invocation)) // Transient hub, dispatch invocation here - if method, ok := getMethod(l.party.invocationTarget(l.hubConn), invocation.Target); !ok { + methodName := l.party.getMethodNameByAlternateName(invocation.Target) + if method, ok := getMethod(l.party.invocationTarget(l.hubConn), methodName); !ok { // Unable to find the method _ = l.info.Log(evt, "getMethod", "error", "missing method", "name", invocation.Target, react, "send completion with error") _ = l.hubConn.Completion(invocation.InvocationID, nil, fmt.Sprintf("Unknown method %s", invocation.Target)) diff --git a/options.go b/options.go index abf3754..1ca5270 100644 --- a/options.go +++ b/options.go @@ -129,3 +129,13 @@ func buildInfoDebugLogger(logger log.Logger, debug bool) (log.Logger, log.Logger debugLogger := log.With(&recoverLogger{level.Debug(logger)}, "caller", log.DefaultCaller) return infoLogger, debugLogger } + +// WithAlternateMethodName sets an alternate method name for hub or receiver method. This might be necessary +// if the invocation target used by existing clients for invocations or by existing servers for callback invocations +// is not a name that can be used as a Go method name. +func WithAlternateMethodName(methodName, alternateName string) func(Party) error { + return func(p Party) error { + p.setAlternateMethodName(methodName, alternateName) + return nil + } +} diff --git a/party.go b/party.go index 174b2c7..df852bb 100644 --- a/party.go +++ b/party.go @@ -44,6 +44,9 @@ type Party interface { enableDetailedErrors() bool setEnableDetailedErrors(enable bool) + setAlternateMethodName(methodName, alternateName string) + getMethodNameByAlternateName(alternateName string) (methodName string) + loggers() (info StructuredLogger, dbg StructuredLogger) setLoggers(info StructuredLogger, dbg StructuredLogger) @@ -69,6 +72,7 @@ func newPartyBase(parentContext context.Context, info log.Logger, dbg log.Logger _enableDetailedErrors: false, _insecureSkipVerify: false, _originPatterns: nil, + methodNamesByAlternateName: make(map[string]string), info: info, dbg: dbg, } @@ -86,6 +90,7 @@ type partyBase struct { _enableDetailedErrors bool _insecureSkipVerify bool _originPatterns []string + methodNamesByAlternateName map[string]string info StructuredLogger dbg StructuredLogger wg sync.WaitGroup @@ -181,3 +186,15 @@ func (p *partyBase) loggers() (info StructuredLogger, debug StructuredLogger) { func (p *partyBase) waitGroup() *sync.WaitGroup { return &p.wg } + +func (p *partyBase) setAlternateMethodName(methodName, alternateName string) { + p.methodNamesByAlternateName[alternateName] = methodName +} + +func (p *partyBase) getMethodNameByAlternateName(alternateName string) (methodName string) { + var ok bool + if methodName, ok = p.methodNamesByAlternateName[alternateName]; !ok { + methodName = alternateName + } + return methodName +} diff --git a/signalr_suite_test.go b/signalr_suite_test.go index b2c3f54..77fe798 100644 --- a/signalr_suite_test.go +++ b/signalr_suite_test.go @@ -18,7 +18,9 @@ func connect(hubProto HubInterface) (Server, *testingConnection) { server, err := NewServer(context.TODO(), SimpleHubFactory(hubProto), testLoggerOption(), ChanReceiveTimeout(200*time.Millisecond), - StreamBufferCapacity(5)) + StreamBufferCapacity(5), + WithAlternateMethodName("Simple", "."), + ) if err != nil { Fail(err.Error()) return nil, nil