diff --git a/broker/broker.go b/broker/broker.go index d4268ee3..12516f63 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -26,17 +26,14 @@ type Cacheable interface { ToCacheEntry() ([]byte, error) } -type PresenceInfo struct { - // Total number of present clients (uniq) - Total int - // Presence records - Records []interface{} -} - // We can extend the presence read functionality in the future // (e.g., add pagination, filtering, etc.) type PresenceInfoOptions struct { - ReturnRecords bool + ReturnRecords bool `json:"return_records,omitempty"` +} + +func NewPresenceInfoOptions() *PresenceInfoOptions { + return &PresenceInfoOptions{ReturnRecords: true} } type PresenceInfoOption func(*PresenceInfoOptions) @@ -83,14 +80,17 @@ type Broker interface { // Adds a new presence record for the stream. Returns true if that's the first // presence record for the presence ID (pid, a unique user presence identifier). - PresenceAdd(stream string, sid string, pid string, info interface{}) error + PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) // Removes a presence record for the stream. Returns true if that was the last // record for the presence ID (pid). - PresenceRemove(stream string, sid string, pid string) error + PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) // Retrieves presence information for the stream (counts, records, etc. depending on the options) - PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) + PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) + + // Marks presence record as finished (for cache expiration) + FinishPresence(sid string) error } // LocalBroker is a single-node broker that can used to store streams data locally @@ -229,14 +229,18 @@ func (LegacyBroker) FinishSession(sid string) error { return nil } -func (LegacyBroker) PresenceAdd(stream string, sid string, pid string, info interface{}) error { - return errors.New("presence not supported") +func (LegacyBroker) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") } -func (LegacyBroker) PresenceRemove(stream string, sid string, pid string) error { - return errors.New("presence not supported") +func (LegacyBroker) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") } -func (LegacyBroker) PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) { +func (LegacyBroker) PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) { return nil, errors.New("presence not supported") } + +func (LegacyBroker) FinishPresence(sid string) error { + return nil +} diff --git a/broker/memory.go b/broker/memory.go index f0ed9f61..acf65826 100644 --- a/broker/memory.go +++ b/broker/memory.go @@ -4,10 +4,12 @@ import ( "context" "errors" "fmt" + "slices" "sync" "time" "github.com/anycable/anycable-go/common" + "github.com/anycable/anycable-go/utils" nanoid "github.com/matoous/go-nanoid" ) @@ -178,6 +180,7 @@ type presenceSessionEntry struct { type presenceEntry struct { info interface{} + id string sessions []string } @@ -200,6 +203,14 @@ func (pe *presenceEntry) remove(sid string) bool { return len(pe.sessions) == 0 } +func (pe *presenceEntry) add(sid string, info interface{}) { + if !slices.Contains(pe.sessions, sid) { + pe.sessions = append(pe.sessions, sid) + } + + pe.info = info +} + type presenceState struct { streams map[string]map[string]*presenceEntry sessions map[string]*presenceSessionEntry @@ -431,6 +442,10 @@ func (b *Memory) FinishSession(sid string) error { } b.sessionsMu.Unlock() + return nil +} + +func (b *Memory) FinishPresence(sid string) error { b.presence.mu.Lock() if sp, ok := b.presence.sessions[sid]; ok { @@ -442,7 +457,7 @@ func (b *Memory) FinishSession(sid string) error { return nil } -func (b *Memory) PresenceAdd(stream string, sid string, pid string, info interface{}) error { +func (b *Memory) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { b.presence.mu.Lock() defer b.presence.mu.Unlock() @@ -450,54 +465,74 @@ func (b *Memory) PresenceAdd(stream string, sid string, pid string, info interfa b.presence.streams[stream] = make(map[string]*presenceEntry) } + if _, ok := b.presence.sessions[sid]; !ok { + b.presence.sessions[sid] = &presenceSessionEntry{ + streams: make(map[string]string), + } + } + + if oldPid, ok := b.presence.sessions[sid].streams[stream]; ok && oldPid != pid { + return nil, errors.New("presence ID mismatch") + } + + b.presence.sessions[sid].streams[stream] = pid + streamPresence := b.presence.streams[stream] + newPresence := false + if _, ok := streamPresence[pid]; !ok { + newPresence = true streamPresence[pid] = &presenceEntry{ info: info, + id: pid, sessions: []string{}, } } streamSessionPresence := streamPresence[pid] - newPresence := len(streamSessionPresence.sessions) == 0 - - streamSessionPresence.sessions = append( - streamSessionPresence.sessions, - sid, - ) - - if _, ok := b.presence.sessions[sid]; !ok { - b.presence.sessions[sid] = &presenceSessionEntry{ - streams: make(map[string]string), - } - } - - b.presence.sessions[sid].streams[stream] = pid + streamSessionPresence.add(sid, info) if newPresence { - b.broadcaster.Broadcast(&common.StreamMessage{ - Stream: stream, - Data: common.PresenceJoinMessage(pid, info), - }) + return &common.PresenceEvent{ + Type: common.PresenceJoinType, + ID: pid, + Info: info, + }, nil } - return nil + return nil, nil } -func (b *Memory) PresenceRemove(stream string, sid string, pid string) error { +func (b *Memory) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { b.presence.mu.Lock() defer b.presence.mu.Unlock() if _, ok := b.presence.streams[stream]; !ok { - return nil + return nil, errors.New("stream not found") + } + + var pid string + + if ses, ok := b.presence.sessions[sid]; ok { + if id, ok := ses.streams[stream]; !ok { + return nil, errors.New("presence info not found") + } else { + pid = id + } + + delete(ses.streams, stream) + + if len(ses.streams) == 0 { + delete(b.presence.sessions, sid) + } } streamPresence := b.presence.streams[stream] if _, ok := streamPresence[pid]; !ok { - return nil + return nil, errors.New("presence record not found") } streamSessionPresence := streamPresence[pid] @@ -512,22 +547,18 @@ func (b *Memory) PresenceRemove(stream string, sid string, pid string) error { delete(b.presence.streams, stream) } - if _, ok := b.presence.sessions[sid]; ok { - delete(b.presence.sessions[sid].streams, stream) - } - if empty { - b.broadcaster.Broadcast(&common.StreamMessage{ - Stream: stream, - Data: common.PresenceLeaveMessage(pid), - }) + return &common.PresenceEvent{ + Type: common.PresenceLeaveType, + ID: pid, + }, nil } - return nil + return nil, nil } -func (b *Memory) PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) { - options := &PresenceInfoOptions{} +func (b *Memory) PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) { + options := NewPresenceInfoOptions() for _, opt := range opts { opt(options) } @@ -535,19 +566,24 @@ func (b *Memory) PresenceInfo(stream string, opts ...PresenceInfoOption) (*Prese b.presence.mu.RLock() defer b.presence.mu.RUnlock() + info := common.NewPresenceInfo() + if _, ok := b.presence.streams[stream]; !ok { - return &PresenceInfo{Total: 0}, nil + return info, nil } streamPresence := b.presence.streams[stream] - info := &PresenceInfo{Total: len(streamPresence)} + info.Total = len(streamPresence) if options.ReturnRecords { - info.Records = make([]interface{}, 0, len(streamPresence)) + info.Records = make([]*common.PresenceEvent, 0, len(streamPresence)) for _, entry := range streamPresence { - info.Records = append(info.Records, entry.info) + info.Records = append(info.Records, &common.PresenceEvent{ + Info: entry.info, + ID: entry.id, + }) } } @@ -635,7 +671,7 @@ func (b *Memory) expirePresence() { toDelete := []string{} for sid, sp := range b.presence.sessions { - if sp.deadline < now { + if sp.deadline > 0 && sp.deadline < now { toDelete = append(toDelete, sid) } } @@ -661,9 +697,15 @@ func (b *Memory) expirePresence() { if empty { delete(b.presence.streams[stream], pid) + msg := &common.PresenceEvent{Type: common.PresenceLeaveType, ID: pid} + leaveMessages = append(leaveMessages, common.StreamMessage{ Stream: stream, - Data: common.PresenceLeaveMessage(pid), + Data: string(utils.ToJSON(msg)), + Meta: &common.StreamMessageMetadata{ + BroadcastType: common.PresenceType, + Transient: true, + }, }) if len(b.presence.streams[stream]) == 0 { @@ -677,8 +719,11 @@ func (b *Memory) expirePresence() { b.presence.mu.Unlock() - // TODO: batch broadcast? - for _, msg := range leaveMessages { - b.broadcaster.Broadcast(&msg) + if b.broadcaster != nil { + // TODO: batch broadcast? + // FIXME: move broadcasts out of broker + for _, msg := range leaveMessages { + b.broadcaster.Broadcast(&msg) + } } } diff --git a/broker/memory_test.go b/broker/memory_test.go index f0a32820..6d8bc6ee 100644 --- a/broker/memory_test.go +++ b/broker/memory_test.go @@ -1,9 +1,11 @@ package broker import ( + "slices" "testing" "time" + "github.com/anycable/anycable-go/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -149,3 +151,93 @@ func TestMemstream_filterByOffset(t *testing.T) { }) require.Error(t, err) } + +func TestMemory_Presence(t *testing.T) { + config := NewConfig() + + broker := NewMemoryBroker(nil, &config) + + ev, err := broker.PresenceAdd("a", "s1", "user_1", map[string]interface{}{"name": "John"}) + require.NoError(t, err) + + assert.Equal(t, "user_1", ev.ID) + assert.Equal(t, "join", ev.Type) + assert.Equal(t, map[string]interface{}{"name": "John"}, ev.Info) + + // Adding presence for the same session with different ID is illegal + ev, err = broker.PresenceAdd("a", "s1", "user_2", map[string]interface{}{"name": "Boo"}) + require.Error(t, err) + assert.Nil(t, ev) + + ev, err = broker.PresenceAdd("a", "s2", "user_1", map[string]interface{}{"name": "Jack"}) + require.NoError(t, err) + assert.Nil(t, ev) + + ev, err = broker.PresenceAdd("a", "s3", "user_2", map[string]interface{}{"name": "Alice"}) + require.NoError(t, err) + assert.Equal(t, "user_2", ev.ID) + + ev, err = broker.PresenceAdd("b", "s3", "user_2", map[string]interface{}{"name": "Alice"}) + require.NoError(t, err) + assert.Equal(t, "user_2", ev.ID) + + info, err := broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 2, info.Total) + assert.Equal(t, 2, len(info.Records)) + + // Make sure the latest info is returned + assert.Truef(t, slices.ContainsFunc(info.Records, func(r *common.PresenceEvent) bool { + return r.ID == "user_1" && (r.Info.(map[string]interface{})["name"] == "Jack") + }), "presence user with user_id and name:Jack not found: %s", info.Records) + + // Now let's check that leave works + ev, err = broker.PresenceRemove("a", "s1") + require.NoError(t, err) + assert.Nil(t, ev) + + info, err = broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 2, info.Total) + + ev, err = broker.PresenceRemove("a", "s2") + require.NoError(t, err) + assert.Equal(t, "user_1", ev.ID) + + info, err = broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 1, info.Total) +} + +func TestMemory_expirePresence(t *testing.T) { + config := NewConfig() + config.PresenceTTL = 1 + + broker := NewMemoryBroker(nil, &config) + + broker.PresenceAdd("a", "s1", "user_1", "john") // nolint: errcheck + broker.PresenceAdd("a", "s2", "user_2", "kate") // nolint: errcheck + + info, err := broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 2, info.Total) + + broker.FinishPresence("s1") // nolint: errcheck + broker.FinishPresence("s2") // nolint: errcheck + + time.Sleep(2 * time.Second) + + broker.PresenceAdd("a", "s3", "user_1", "john") // nolint: errcheck + + broker.expire() + + info, err = broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 1, info.Total) + assert.Equal(t, "user_1", info.Records[0].ID) +} diff --git a/broker/nats.go b/broker/nats.go index ce2d09a5..5e5c8019 100644 --- a/broker/nats.go +++ b/broker/nats.go @@ -482,18 +482,22 @@ func (n *NATS) Reset() error { return nil } -func (n *NATS) PresenceAdd(stream string, sid string, pid string, info interface{}) error { - return errors.New("presence not supported") +func (n *NATS) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") } -func (n *NATS) PresenceRemove(stream string, sid string, pid string) error { - return errors.New("presence not supported") +func (n *NATS) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") } -func (n *NATS) PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) { +func (n *NATS) PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) { return nil, errors.New("presence not supported") } +func (n *NATS) FinishPresence(sid string) error { + return nil +} + func (n *NATS) add(stream string, data string) (uint64, error) { err := n.ensureStreamExists(stream) diff --git a/cli/options.go b/cli/options.go index 38b0ea71..e15c113b 100644 --- a/cli/options.go +++ b/cli/options.go @@ -637,6 +637,12 @@ func brokerCLIFlags(c *config.Config) []cli.Flag { Value: c.Broker.SessionsTTL, Destination: &c.Broker.SessionsTTL, }, + &cli.Int64Flag{ + Name: "presence_ttl", + Usage: "TTL for presence information (seconds)", + Value: c.Broker.PresenceTTL, + Destination: &c.Broker.PresenceTTL, + }, }) } diff --git a/common/common.go b/common/common.go index 2f63ec63..feae779d 100644 --- a/common/common.go +++ b/common/common.go @@ -61,14 +61,15 @@ const ( RejectedType = "reject_subscription" // Not supported by Action Cable currently UnsubscribedType = "unsubscribed" + ErrorType = "error" HistoryConfirmedType = "confirm_history" HistoryRejectedType = "reject_history" + PresenceInfoType = "info" PresenceJoinType = "join" PresenceLeaveType = "leave" - PresenceInfoType = "presence" - PresenceErrorType = "presence_error" + PresenceType = "presence" WhisperType = "whisper" ) @@ -223,12 +224,25 @@ func (c *ConnectResult) ToCallResult() *CallResult { return &res } -type PresenceInfo struct { +type PresenceEvent struct { Type string `json:"type,omitempty"` Info interface{} `json:"info,omitempty"` ID string `json:"id"` } +type PresenceInfo struct { + // Type is always "presence" + Type string `json:"type,omitempty"` + // Total number of present clients (uniq) + Total int `json:"total"` + // Presence records + Records []*PresenceEvent `json:"records,omitempty"` +} + +func NewPresenceInfo() *PresenceInfo { + return &PresenceInfo{Type: PresenceInfoType} +} + // CommandResult is a result of performing controller action, // which contains informations about streams to subscribe, // messages to sent and broadcast. @@ -240,7 +254,7 @@ type CommandResult struct { StoppedStreams []string Transmissions []string Broadcasts []*StreamMessage - Presence *PresenceInfo + Presence *PresenceEvent CState map[string]string IState map[string]string DisconnectInterest int @@ -314,7 +328,7 @@ type Message struct { Identifier string `json:"identifier"` Data interface{} `json:"data,omitempty"` History HistoryRequest `json:"history,omitempty"` - Presence *PresenceInfo `json:"presence,omitempty"` + Presence *PresenceEvent `json:"presence,omitempty"` } func (m *Message) LogValue() slog.Value { @@ -478,18 +492,18 @@ func NewDisconnectMessage(reason string, reconnect bool) *DisconnectMessage { // Reply represents an outgoing client message type Reply struct { - Type string `json:"type,omitempty"` - Identifier string `json:"identifier,omitempty"` - Message interface{} `json:"message,omitempty"` - Presence *PresenceInfo `json:"presence,omitempty"` - Reason string `json:"reason,omitempty"` - Reconnect bool `json:"reconnect,omitempty"` - StreamID string `json:"stream_id,omitempty"` - Epoch string `json:"epoch,omitempty"` - Offset uint64 `json:"offset,omitempty"` - Sid string `json:"sid,omitempty"` - Restored bool `json:"restored,omitempty"` - RestoredIDs []string `json:"restored_ids,omitempty"` + Type string `json:"type,omitempty"` + Identifier string `json:"identifier,omitempty"` + Message interface{} `json:"message,omitempty"` + Presence *PresenceEvent `json:"presence,omitempty"` + Reason string `json:"reason,omitempty"` + Reconnect bool `json:"reconnect,omitempty"` + StreamID string `json:"stream_id,omitempty"` + Epoch string `json:"epoch,omitempty"` + Offset uint64 `json:"offset,omitempty"` + Sid string `json:"sid,omitempty"` + Restored bool `json:"restored,omitempty"` + RestoredIDs []string `json:"restored_ids,omitempty"` } func (r *Reply) LogValue() slog.Value { @@ -576,13 +590,3 @@ func RejectionMessage(identifier string) string { func DisconnectionMessage(reason string, reconnect bool) string { return string(utils.ToJSON(DisconnectMessage{Type: DisconnectType, Reason: reason, Reconnect: reconnect})) } - -// PresenceJoinMessage returns a presence message for the specified event and data -func PresenceJoinMessage(id string, info interface{}) string { - return string(utils.ToJSON(Reply{Type: PresenceJoinType, Presence: &PresenceInfo{ID: id, Info: info}})) -} - -// PresenceLeaveMessage returns a presence message for the specified event and data -func PresenceLeaveMessage(id string) string { - return string(utils.ToJSON(Reply{Type: PresenceLeaveType, Presence: &PresenceInfo{ID: id}})) -} diff --git a/docs/.trieve.yml b/docs/.trieve.yml index 12565e0d..a5815d9e 100644 --- a/docs/.trieve.yml +++ b/docs/.trieve.yml @@ -32,6 +32,7 @@ pages: - "./broker.md" - "./configuration.md" - "./embedded_nats.md" + - "./presence.md" - "./getting_started.md" - "./health_checking.md" - "./instrumentation.md" diff --git a/docs/Readme.md b/docs/Readme.md index dcf6e8b5..6909c2dc 100644 --- a/docs/Readme.md +++ b/docs/Readme.md @@ -12,5 +12,6 @@ * [Binary formats](binary_formats.md) * [JWT identification](jwt_identification.md) * [Signed streams](signed_streams.md) +* [Presence](presence.md) * [Embedded NATS](embedded_nats.md) * [Using as a library](library.md) diff --git a/docs/broker.md b/docs/broker.md index 3cbe2a88..e1ffe820 100644 --- a/docs/broker.md +++ b/docs/broker.md @@ -1,11 +1,12 @@ # Broker deep dive -Broker is a component of AnyCable-Go responsible for keeping streams and sessions information in a cache-like storage. It drives the [Reliable Streams](./reliable_streams.md) feature. +Broker is a component of AnyCable-Go responsible for keeping streams, sessions and presence information in a cache-like storage. It drives the [Reliable Streams](./reliable_streams.md) and [Presence](./presence.md) features. Broker implements features that can be characterized as _hot cache utilities_: - Handling incoming broadcast messages and storing them in a cache—that could help clients to receive missing broadcasts (triggered while the client was offline, for example). - Persisting client states—to make it possible to restore on re-connection (by providing a _session id_ of the previous connection). +- Keeping per-channel presence information. ## Client-server communication diff --git a/docs/presence.md b/docs/presence.md new file mode 100644 index 00000000..06c99305 --- /dev/null +++ b/docs/presence.md @@ -0,0 +1,74 @@ +# Presence tracking + +AnyCable comes with a built-in presence tracking support for your real-time applications. No need to write custom code and deal with storage mechanisms to know who's online in your channels. + +## Overview + +AnyCable presence allows channel subscribers to share their presence information with other clients and track the changes in the channel's presence set. Presence data can be used to display a list of online users, track user activity, etc. + +## Quick start + +Presence is a part of the [broker](./broker.md) component, so you must enable it either via the `broker` configuration preset or manually: + +```sh +$ anycable-go --presets=broker + +# or + +$ anycable-go --broker=memory +``` + +> 🚧 Currently, presence tracking is only supported by the memory broker. + +Now, you can use the presence API in your application. For example, using [AnyCable JS client](https://github.com/anycable/anycable-client): + +```js +import { createCable } from '@anycable/web' +// or for non-web projects +// import { createCable } from '@anycable/core' + +const cable = createCable({protocol: 'actioncable-v1-ext-json'}) + +const channel = cable.streamFrom('room/42'); + +// join the channel's presence set +channel.presence.join(user.id, { name: user.name }) + +// get the current presence state +const presence = await chatChannel.presence.info() + +// subscribe to presence events +channel.on("presence", (event) => { + const { type, info, id } = event + + if (type === "join") { + console.log(`${info.name} joined the channel`) + } else if (type === "leave") { + console.log(`${id} left the channel`) + } +}) +``` + +## Presence lifecycle + +Clients join the presence set explicitly by performing the `presence` command. The `join` event is sent to all subscribers (including the initiator) with the presence information, but only if the **presence ID** (provided by the client) hasn't been registered yet. Thus, multiple sessions with the same ID are treated as a single presence record. + +Clients may explicitly leave the presence set by performing the `leave` command or by unsubscribing from the channel. The `leave` event is sent to all subscribers only if no other sessions with the same ID are left in the presence set. + +When a client disconnects without explicitly leaving or unsubscribing the channel, it's present information stays in the set for a short period of time. That prevents the burst of `join` / `leave` events when the client reconnects frequently. + +## Configuration + +You can configure the presence expiration time (for disconnected clients) via the `--presence_ttl` option. The default value is 15 seconds. + +## Presence for channels + +> 🚧 Presence for channels is to be implemented. You can only use presence with [signed streams](./signed_streams.md) for now. + +## Presence API + +> 🚧 Presence REST API is to be implemented. You can only use the presence API via the WebSocket connection. + +## Presence webhooks + +> 🚧 Presence webhooks are to be implemented, too. diff --git a/forspell.dict b/forspell.dict index 6b882f6b..366e4af8 100644 --- a/forspell.dict +++ b/forspell.dict @@ -67,3 +67,4 @@ reconnection norpc noauth jid +webhooks diff --git a/mocks/Broker.go b/mocks/Broker.go index ada5a4ca..422a51bf 100644 --- a/mocks/Broker.go +++ b/mocks/Broker.go @@ -52,6 +52,24 @@ func (_m *Broker) CommitSession(sid string, session broker.Cacheable) error { return r0 } +// FinishPresence provides a mock function with given fields: sid +func (_m *Broker) FinishPresence(sid string) error { + ret := _m.Called(sid) + + if len(ret) == 0 { + panic("no return value specified for FinishPresence") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(sid) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // FinishSession provides a mock function with given fields: sid func (_m *Broker) FinishSession(sid string) error { ret := _m.Called(sid) @@ -141,25 +159,37 @@ func (_m *Broker) HistorySince(stream string, ts int64) ([]common.StreamMessage, } // PresenceAdd provides a mock function with given fields: stream, sid, pid, info -func (_m *Broker) PresenceAdd(stream string, sid string, pid string, info interface{}) error { +func (_m *Broker) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { ret := _m.Called(stream, sid, pid, info) if len(ret) == 0 { panic("no return value specified for PresenceAdd") } - var r0 error - if rf, ok := ret.Get(0).(func(string, string, string, interface{}) error); ok { + var r0 *common.PresenceEvent + var r1 error + if rf, ok := ret.Get(0).(func(string, string, string, interface{}) (*common.PresenceEvent, error)); ok { + return rf(stream, sid, pid, info) + } + if rf, ok := ret.Get(0).(func(string, string, string, interface{}) *common.PresenceEvent); ok { r0 = rf(stream, sid, pid, info) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*common.PresenceEvent) + } } - return r0 + if rf, ok := ret.Get(1).(func(string, string, string, interface{}) error); ok { + r1 = rf(stream, sid, pid, info) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // PresenceInfo provides a mock function with given fields: stream, opts -func (_m *Broker) PresenceInfo(stream string, opts ...broker.PresenceInfoOption) (*broker.PresenceInfo, error) { +func (_m *Broker) PresenceInfo(stream string, opts ...broker.PresenceInfoOption) (*common.PresenceInfo, error) { _va := make([]interface{}, len(opts)) for _i := range opts { _va[_i] = opts[_i] @@ -173,16 +203,16 @@ func (_m *Broker) PresenceInfo(stream string, opts ...broker.PresenceInfoOption) panic("no return value specified for PresenceInfo") } - var r0 *broker.PresenceInfo + var r0 *common.PresenceInfo var r1 error - if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) (*broker.PresenceInfo, error)); ok { + if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) (*common.PresenceInfo, error)); ok { return rf(stream, opts...) } - if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) *broker.PresenceInfo); ok { + if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) *common.PresenceInfo); ok { r0 = rf(stream, opts...) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*broker.PresenceInfo) + r0 = ret.Get(0).(*common.PresenceInfo) } } @@ -195,22 +225,34 @@ func (_m *Broker) PresenceInfo(stream string, opts ...broker.PresenceInfoOption) return r0, r1 } -// PresenceRemove provides a mock function with given fields: stream, sid, pid -func (_m *Broker) PresenceRemove(stream string, sid string, pid string) error { - ret := _m.Called(stream, sid, pid) +// PresenceRemove provides a mock function with given fields: stream, sid +func (_m *Broker) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { + ret := _m.Called(stream, sid) if len(ret) == 0 { panic("no return value specified for PresenceRemove") } - var r0 error - if rf, ok := ret.Get(0).(func(string, string, string) error); ok { - r0 = rf(stream, sid, pid) + var r0 *common.PresenceEvent + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (*common.PresenceEvent, error)); ok { + return rf(stream, sid) + } + if rf, ok := ret.Get(0).(func(string, string) *common.PresenceEvent); ok { + r0 = rf(stream, sid) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*common.PresenceEvent) + } } - return r0 + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(stream, sid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // RestoreSession provides a mock function with given fields: from diff --git a/node/broker_integration_test.go b/node/broker_integration_test.go index 95f70b1e..c2218a16 100644 --- a/node/broker_integration_test.go +++ b/node/broker_integration_test.go @@ -393,6 +393,164 @@ func sharedIntegrationHistory(t *testing.T, node *Node, controller *mocks.Contro }) } +// A test to verify the presence flow. +// +// SETUP: +// - Two sessions are created (authenticated and subscribed with presence stream). +// +// TEST 1 — join/leave events and info: +// - First session joins the channel. +// - Both sessions receive join event. +// - Second session requests presence info. +// - Second session joins the channel. +// - First session leaves the channel. +// - Both sessions receive leave event. +// - Second session requests presence info. +// +// TEST 2 — presence expiration: +// - Both sessions joins the channel. +// - Both sessions receive join event. +// - First session disconnects. +// - Wait for expiration. +// - Second session receives leave event. +// - Second session requests presence info. +func TestIntegrationPresence_Memory(t *testing.T) { + node, controller := setupIntegrationNode() + + bconf := broker.NewConfig() + bconf.PresenceTTL = 2 + + subscriber := pubsub.NewLegacySubscriber(node) + + br := broker.NewMemoryBroker(subscriber, &bconf) + node.SetBroker(br) + + require.NoError(t, br.Start(nil)) + + go node.Start() // nolint:errcheck + defer node.Shutdown(context.Background()) // nolint:errcheck + + sharedIntegrationPresence(t, node, controller) +} + +func sharedIntegrationPresence(t *testing.T, node *Node, controller *mocks.Controller) { + controller. + On("Subscribe", "sasha", mock.Anything, "sasha", "chat_1"). + Return(&common.CommandResult{ + Status: common.SUCCESS, + Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`}, + Streams: []string{"presence_1", "messages_1"}, + IState: map[string]string{common.PRESENCE_STREAM_STATE: "presence_1"}, + }, nil) + controller. + On("Subscribe", "mia", mock.Anything, "mia", "chat_1"). + Return(&common.CommandResult{ + Status: common.SUCCESS, + Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`}, + Streams: []string{"presence_1", "messages_1"}, + IState: map[string]string{common.PRESENCE_STREAM_STATE: "presence_1"}, + }, nil) + controller. + On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(&common.CommandResult{ + Status: common.SUCCESS, + }, nil) + + setupSessions := func() (*Session, *Session, func()) { + session := requireAuthenticatedSession(t, node, "sasha") + _, err := node.Subscribe( + session, + &common.Message{ + Identifier: "chat_1", + Command: "subscribe", + }) + require.NoError(t, err) + assertReceive(t, session, `{"type":"confirm","identifier":"chat_1"}`) + + session2 := requireAuthenticatedSession(t, node, "mia") + _, err = node.Subscribe( + session2, + &common.Message{ + Identifier: "chat_1", + Command: "subscribe", + }) + require.NoError(t, err) + assertReceive(t, session2, `{"type":"confirm","identifier":"chat_1"}`) + + return session, session2, func() { + node.Unsubscribe(session, &common.Message{Identifier: "chat_1", Command: "unsubscribe"}) // nolint:errcheck + node.Unsubscribe(session2, &common.Message{Identifier: "chat_1", Command: "unsubscribe"}) // nolint:errcheck + } + } + + t.Run("Join and leave", func(t *testing.T) { + sasha, mia, cleanup := setupSessions() + defer cleanup() + + err := node.PresenceJoin(sasha, &common.Message{Identifier: "chat_1", Presence: &common.PresenceEvent{ID: "42", Info: map[string]interface{}{"name": "Sasha"}}}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"42","info":{"name":"Sasha"},"type":"join"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"42","info":{"name":"Sasha"},"type":"join"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":1,"records":[{"id":"42","info":{"name":"Sasha"}}]}}`) + + err = node.PresenceJoin(mia, &common.Message{Identifier: "chat_1", Presence: &common.PresenceEvent{ID: "13", Info: map[string]interface{}{"name": "Mia"}}}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"13","info":{"name":"Mia"},"type":"join"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"13","info":{"name":"Mia"},"type":"join"}}`) + + err = node.Presence(sasha, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":2,"records":[{"id":"42","info":{"name":"Sasha"}},{"id":"13","info":{"name":"Mia"}}]}}`) + + err = node.PresenceLeave(sasha, &common.Message{Identifier: "chat_1"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"42","type":"leave"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"42","type":"leave"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":1,"records":[{"id":"13","info":{"name":"Mia"}}]}}`) + }) + + t.Run("Presence expiration", func(t *testing.T) { + sasha, mia, cleanup := setupSessions() + defer cleanup() + + err := node.PresenceJoin(sasha, &common.Message{Identifier: "chat_1", Presence: &common.PresenceEvent{ID: "142", Info: map[string]interface{}{"name": "Rickie"}}}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"142","info":{"name":"Rickie"},"type":"join"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"142","info":{"name":"Rickie"},"type":"join"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":1,"records":[{"id":"142","info":{"name":"Rickie"}}]}}`) + + err = node.Disconnect(sasha) + require.NoError(t, err) + + // Wait for expiration to happen + time.Sleep(4 * time.Second) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"142","type":"leave"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":0}}`) + }) +} + func setupIntegrationNode() (*Node, *mocks.Controller) { config := NewConfig() config.HubGopoolSize = 2 @@ -422,10 +580,19 @@ func assertReceive(t *testing.T, s *Session, expected string) { msg, err := s.conn.Read() require.NoError(t, err) + var parsedMessage map[string]interface{} + var expectedMessage map[string]interface{} + + err = json.Unmarshal(msg, &parsedMessage) + require.NoError(t, err) + + err = json.Unmarshal([]byte(expected), &expectedMessage) + require.NoError(t, err) + assert.Equal( t, - expected, - string(msg), + expectedMessage, + parsedMessage, ) } diff --git a/node/node.go b/node/node.go index b363d9d4..332fe19b 100644 --- a/node/node.go +++ b/node/node.go @@ -178,6 +178,10 @@ func (n *Node) HandleCommand(s *Session, msg *common.Message) (err error) { err = n.History(s, msg) case "presence": err = n.Presence(s, msg) + case "join": + err = n.PresenceJoin(s, msg) + case "leave": + err = n.PresenceLeave(s, msg) case "whisper": err = n.Whisper(s, msg) default: @@ -499,6 +503,8 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResu s.Log.Debug("controller unsubscribe", "response", res, "err", err) + presenceStream := "" + if err != nil { if res == nil || res.Status == common.ERROR { return nil, errorx.Decorate(err, "failed to unsubscribe from %s", msg.Identifier) @@ -507,6 +513,8 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResu // Make sure to remove all streams subscriptions res.StopAllStreams = true + presenceStream = s.env.GetChannelStateField(msg.Identifier, common.PRESENCE_STREAM_STATE) + s.env.RemoveChannelState(msg.Identifier) s.subscriptions.RemoveChannel(msg.Identifier) @@ -525,6 +533,13 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResu } } + // Make sure presence is removed on explicit unsubscribe + if presenceStream != "" { + if _, err := n.broker.PresenceRemove(presenceStream, s.GetID()); err != nil { + s.Log.Error("failed to remove presence", "error", err) + } + } + return res, nil } @@ -702,7 +717,7 @@ func (n *Node) Presence(s *Session, msg *common.Message) error { s.smu.Unlock() - var options *broker.PresenceInfoOptions + options := broker.NewPresenceInfoOptions() if msg.Data != nil { buf, err := json.Marshal(&msg.Data) @@ -715,15 +730,16 @@ func (n *Node) Presence(s *Session, msg *common.Message) error { if err != nil { s.Send(&common.Reply{ - Type: common.PresenceErrorType, + Type: common.PresenceType, Identifier: msg.Identifier, + Message: &common.PresenceInfo{Type: common.ErrorType}, }) return err } s.Send(&common.Reply{ - Type: common.PresenceInfoType, + Type: common.PresenceType, Identifier: msg.Identifier, Message: info, }) @@ -731,6 +747,34 @@ func (n *Node) Presence(s *Session, msg *common.Message) error { return nil } +// PresenceJoin adds the session to the presence state for the specified identifier +func (n *Node) PresenceJoin(s *Session, msg *common.Message) error { + s.smu.Lock() + + if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { + s.smu.Unlock() + return fmt.Errorf("unknown subscription %s", msg.Identifier) + } + + s.smu.Unlock() + + return n.handlePresenceReply(s, msg.Identifier, common.PresenceJoinType, msg.Presence) +} + +// PresenceLeave removes the session to the presence state for the specified identifier +func (n *Node) PresenceLeave(s *Session, msg *common.Message) error { + s.smu.Lock() + + if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { + s.smu.Unlock() + return fmt.Errorf("unknown subscription %s", msg.Identifier) + } + + s.smu.Unlock() + + return n.handlePresenceReply(s, msg.Identifier, common.PresenceLeaveType, msg.Presence) +} + // Broadcast message to stream (locally) func (n *Node) Broadcast(msg *common.StreamMessage) { n.metrics.CounterIncrement(metricsBroadcastMsg) @@ -762,6 +806,8 @@ func (n *Node) Disconnect(s *Session) error { n.broker.FinishSession(s.GetID()) // nolint:errcheck } + n.broker.FinishPresence(s.GetID()) // nolint:errcheck + if n.IsShuttingDown() { // Make sure session is removed from hub, so we don't try to send // broadcast messages to them @@ -912,11 +958,7 @@ func (n *Node) handleCallReply(s *Session, reply *common.CallResult) bool { return isDirty } -func (n *Node) handlePresenceReply(s *Session, identifier string, event string, presence *common.PresenceInfo) error { - if presence == nil { - return nil - } - +func (n *Node) handlePresenceReply(s *Session, identifier string, event string, presence *common.PresenceEvent) error { // Check that the presence stream is configured (thus, the feature is enabled) env := s.GetEnv() if env == nil { @@ -932,15 +974,30 @@ func (n *Node) handlePresenceReply(s *Session, identifier string, event string, sid := s.GetID() var err error + var msg *common.PresenceEvent if event == common.PresenceJoinType { // nolint:gocritic - err = n.broker.PresenceAdd(stream, sid, presence.ID, presence.Info) + if presence == nil { + return errors.New("presence data is missing") + } + msg, err = n.broker.PresenceAdd(stream, sid, presence.ID, presence.Info) } else if event == common.PresenceLeaveType { - err = n.broker.PresenceRemove(stream, sid, presence.ID) + msg, err = n.broker.PresenceRemove(stream, sid) } else { return fmt.Errorf("unknown presence event: %s", event) } + if msg != nil { + n.Broadcast(&common.StreamMessage{ + Stream: stream, + Data: string(utils.ToJSON(msg)), + Meta: &common.StreamMessageMetadata{ + BroadcastType: common.PresenceType, + Transient: true, + }, + }) + } + return err } diff --git a/streams/controller.go b/streams/controller.go index 9e297cb1..eb16182d 100644 --- a/streams/controller.go +++ b/streams/controller.go @@ -109,13 +109,17 @@ func (c *Controller) Subscribe(sid string, env *common.SessionEnv, ids string, i c.log.With("identifier", identifier).Debug("verified", "stream", stream) } - state := map[string]string{} + var state map[string]string if request.whisper { + state = make(map[string]string) state[common.WHISPER_STREAM_STATE] = stream } if request.presence { + if state == nil { + state = make(map[string]string) + } state[common.PRESENCE_STREAM_STATE] = stream }