Skip to content

Commit

Permalink
Merge pull request #53 from muzzammilshahid/progressive-call-results
Browse files Browse the repository at this point in the history
Implement progressive call results
  • Loading branch information
muzzammilshahid authored Sep 20, 2024
2 parents 432893e + 227026d commit 8226aa1
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 24 deletions.
40 changes: 40 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,43 @@ func TestPublishSubscribe(t *testing.T) {
log.Println(event)
})
}

func TestProgressiveCallResults(t *testing.T) {
session := connect(t)

reg, err := session.Register(
"foo.bar.progress",
func(ctx context.Context, invocation *xconn.Invocation) *xconn.Result {
// Send progress
for i := 1; i <= 3; i++ {
err := invocation.SendProgress([]any{i}, nil)
require.NoError(t, err)
}

// Return final result
return &xconn.Result{Arguments: []any{"done"}}
},
nil,
)
require.NoError(t, err)
require.NotNil(t, reg)

t.Run("ProgressiveCall", func(t *testing.T) {
// Store received progress updates
progressUpdates := make([]int, 0)

result, err := session.CallProgress(context.Background(), "foo.bar.progress", nil, nil, nil,
func(progressiveResult *xconn.Result) {
progress := int(progressiveResult.Arguments[0].(float64))
// Collect received progress
progressUpdates = append(progressUpdates, progress)
})
require.NoError(t, err)

// Verify progressive updates received correctly
require.Equal(t, []int{1, 2, 3}, progressUpdates)

// Verify the final result
require.Equal(t, "done", result.Arguments[0])
})
}
50 changes: 50 additions & 0 deletions examples/rpc_progressive_call_results/callee/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package main

import (
"context"
"log"
"os"
"os/signal"
"time"

"github.com/xconnio/xconn-go"
)

const procedureProgressDownload = "io.xconn.progress.download"

func main() {
// Create and connect a callee client to server
ctx := context.Background()
client := xconn.Client{}
callee, err := client.Connect(ctx, "ws://localhost:8080/ws", "realm1")
if err != nil {
log.Fatalf("Failed to connect to server: %s", err)
}
defer func() { _ = callee.Leave() }()

invocationHandler := func(ctx context.Context, invocation *xconn.Invocation) *xconn.Result {
fileSize := 100 // Simulate a file size of 100 units
for i := 0; i <= fileSize; i += 10 {
progress := i * 100 / fileSize
if err := invocation.SendProgress([]any{progress}, nil); err != nil {
return &xconn.Result{Err: "wamp.error.canceled", Arguments: []any{err.Error()}}
}
time.Sleep(500 * time.Millisecond) // Simulate time taken for download
}

return &xconn.Result{Arguments: []any{"Download complete!"}}
}

registration, err := callee.Register(procedureProgressDownload, invocationHandler, nil)
if err != nil {
log.Fatalf("Failed to register method: %s", err)
}
defer func() { _ = callee.Unregister(registration.ID) }()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
select {
case <-sigChan:
case <-ctx.Done():
}
}
32 changes: 32 additions & 0 deletions examples/rpc_progressive_call_results/caller/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package main

import (
"context"
"fmt"
"log"

"github.com/xconnio/xconn-go"
)

const procedureProgressDownload = "io.xconn.progress.download"

func main() {
// Create and connect a caller client to server
ctx := context.Background()
client := xconn.Client{}
caller, err := client.Connect(ctx, "ws://localhost:8080/ws", "realm1")
if err != nil {
log.Fatalf("Failed to connect to server: %s", err)
}
defer func() { _ = caller.Leave() }()

result, err := caller.CallProgress(ctx, procedureProgressDownload, nil, nil, nil, func(result *xconn.Result) {
progress := result.Arguments[0]
fmt.Printf("Download progress: %v%%\n", progress)
})
if err != nil {
log.Fatalf("Call failed: %s", err)
}

fmt.Println(result.Arguments[0])
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/gobwas/ws v1.4.0
github.com/projectdiscovery/ratelimit v0.0.50
github.com/stretchr/testify v1.9.0
github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0
github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54
github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9
golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d
gopkg.in/yaml.v3 v3.0.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0 h1:IU8Sn5EkI0vO/r8C36lEWM3wiSMi7NpxNmpUmt2fUyg=
github.com/xconnio/wampproto-go v0.0.0-20240801143427-b722ee9231d0/go.mod h1:/b7EyR1X9EkOHPQBJGz1KvdjClo1GsalBGIzjQU5+i4=
github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54 h1:uqKiqnmD6XSnX65WbUUNmIyW4L0oaPeOQPytzrxZPyg=
github.com/xconnio/wampproto-go v0.0.0-20240920091217-fd8f83f21c54/go.mod h1:/b7EyR1X9EkOHPQBJGz1KvdjClo1GsalBGIzjQU5+i4=
github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9 h1:N0W6uTElFFj/nl88fAtCwUw0y0pdHbtn3QPQri/iGsw=
github.com/xconnio/wampproto-protobuf/go v0.0.0-20240706133816-0ca5f0268ce9/go.mod h1:k3t5aYBC+1ujppNAaIgu+Kn7oryRSwsP3o362HkAAho=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
Expand Down
98 changes: 77 additions & 21 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

type InvocationHandler func(ctx context.Context, invocation *Invocation) *Result
type EventHandler func(event *Event)
type ProgressHandler func(result *Result)

type Session struct {
base BaseSession
Expand All @@ -29,6 +30,7 @@ type Session struct {
unregisterRequests sync.Map
registrations sync.Map
callRequests sync.Map
progressHandlers sync.Map

// publish subscribe data structures
subscribeRequests sync.Map
Expand All @@ -52,6 +54,7 @@ func NewSession(base BaseSession, serializer serializers.Serializer) *Session {
unregisterRequests: sync.Map{},
registrations: sync.Map{},
callRequests: sync.Map{},
progressHandlers: sync.Map{},

subscribeRequests: sync.Map{},
unsubscribeRequests: sync.Map{},
Expand Down Expand Up @@ -120,8 +123,23 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
return fmt.Errorf("received RESULT for unknown request")
}

req := request.(chan *CallResponse)
req <- &CallResponse{msg: result}
progress, _ := result.Details()[wampproto.OptionProgress].(bool)
if progress {
progressHandler, exists := s.progressHandlers.Load(result.RequestID())
if exists {
progHandler := progressHandler.(ProgressHandler)
progHandler(&Result{
Arguments: result.Args(),
KwArguments: result.KwArgs(),
Details: result.Details(),
})
}
} else {
req := request.(chan *CallResponse)
req <- &CallResponse{msg: result}
s.progressHandlers.Delete(result.RequestID())
}

case messages.MessageTypeInvocation:
invocation := msg.(*messages.Invocation)
end, _ := s.registrations.Load(invocation.RegistrationID())
Expand All @@ -133,24 +151,45 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
Details: invocation.Details(),
}

var msgToSend messages.Message
res := endpoint(context.Background(), inv)
if res.Err != "" {
msgToSend = messages.NewError(
int64(invocation.Type()), invocation.RequestID(), map[string]any{}, res.Err, res.Arguments, res.KwArguments,
)
} else {
msgToSend = messages.NewYield(invocation.RequestID(), nil, res.Arguments, res.KwArguments)
receiveProgress, _ := invocation.Details()[wampproto.OptionReceiveProgress].(bool)
if receiveProgress {
inv.SendProgress = func(arguments []any, kwArguments map[string]any) error {
yield := messages.NewYield(invocation.RequestID(), map[string]any{"progress": true}, arguments, kwArguments)
payload, err := s.proto.SendMessage(yield)
if err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}

if err = s.base.Write(payload); err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}
return nil
}
}

payload, err := s.proto.SendMessage(msgToSend)
if err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}
go func() {
var msgToSend messages.Message
res := endpoint(context.Background(), inv)
if res.Err != "" {
msgToSend = messages.NewError(
int64(invocation.Type()), invocation.RequestID(), map[string]any{}, res.Err, res.Arguments, res.KwArguments,
)
} else {
msgToSend = messages.NewYield(invocation.RequestID(), nil, res.Arguments, res.KwArguments)
}

payload, err := s.proto.SendMessage(msgToSend)
if err != nil {
log.Println("failed to send yield: %w", err)
return
}

if err = s.base.Write(payload); err != nil {
log.Println("failed to send yield: %w", err)
return
}
}()

if err = s.base.Write(payload); err != nil {
return fmt.Errorf("failed to send yield: %w", err)
}
case messages.MessageTypeSubscribed:
subscribed := msg.(*messages.Subscribed)
request, exists := s.subscribeRequests.Load(subscribed.RequestID())
Expand Down Expand Up @@ -323,10 +362,7 @@ func (s *Session) Unregister(registrationID int64) error {
}
}

func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs map[string]any,
options map[string]any) (*Result, error) {

call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs)
func (s *Session) call(ctx context.Context, call *messages.Call) (*Result, error) {
toSend, err := s.proto.SendMessage(call)
if err != nil {
return nil, err
Expand Down Expand Up @@ -356,6 +392,26 @@ func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs
}
}

func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs map[string]any,
options map[string]any) (*Result, error) {

call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs)
return s.call(ctx, call)
}

func (s *Session) CallProgress(ctx context.Context, procedure string, args []any, kwArgs map[string]any,
options map[string]any, progressHandler ProgressHandler) (*Result, error) {

call := messages.NewCall(s.idGen.NextID(), options, procedure, args, kwArgs)
if progressHandler == nil {
progressHandler = func(result *Result) {}
}
s.progressHandlers.Store(call.RequestID(), progressHandler)
call.Options()[wampproto.OptionReceiveProgress] = true

return s.call(ctx, call)
}

func (s *Session) Subscribe(topic string, handler EventHandler, options map[string]any) (*Subscription, error) {
subscribe := messages.NewSubscribe(s.idGen.NextID(), options, topic)
toSend, err := s.proto.SendMessage(subscribe)
Expand Down
4 changes: 4 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,15 @@ type Event struct {
Details map[string]any
}

type SendProgress func(arguments []any, kwArguments map[string]any) error

type Invocation struct {
Procedure string
Arguments []any
KwArguments map[string]any
Details map[string]any

SendProgress SendProgress
}

type Result struct {
Expand Down

0 comments on commit 8226aa1

Please sign in to comment.