diff --git a/broker/broker.go b/broker/broker.go index 928af664..12516f63 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -26,8 +26,29 @@ type Cacheable interface { ToCacheEntry() ([]byte, error) } +// We can extend the presence read functionality in the future +// (e.g., add pagination, filtering, etc.) +type PresenceInfoOptions struct { + ReturnRecords bool `json:"return_records,omitempty"` +} + +func NewPresenceInfoOptions() *PresenceInfoOptions { + return &PresenceInfoOptions{ReturnRecords: true} +} + +type PresenceInfoOption func(*PresenceInfoOptions) + +func WithPresenceInfoOptions(opts *PresenceInfoOptions) PresenceInfoOption { + return func(o *PresenceInfoOptions) { + if opts != nil { + *o = *opts + } + } +} + // Broker is responsible for: // - Managing streams history. +// - Managing presence information. // - Keeping client states for recovery. // - Distributing broadcasts across nodes. // @@ -56,6 +77,20 @@ type Broker interface { RestoreSession(from string) ([]byte, error) // Marks session as finished (for cache expiration) FinishSession(sid string) error + + // Adds a new presence record for the stream. Returns true if that's the first + // presence record for the presence ID (pid, a unique user presence identifier). + PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) + + // Removes a presence record for the stream. Returns true if that was the last + // record for the presence ID (pid). + PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) + + // Retrieves presence information for the stream (counts, records, etc. depending on the options) + PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) + + // Marks presence record as finished (for cache expiration) + FinishPresence(sid string) error } // LocalBroker is a single-node broker that can used to store streams data locally @@ -193,3 +228,19 @@ func (LegacyBroker) RestoreSession(from string) ([]byte, error) { func (LegacyBroker) FinishSession(sid string) error { return nil } + +func (LegacyBroker) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") +} + +func (LegacyBroker) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") +} + +func (LegacyBroker) PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) { + return nil, errors.New("presence not supported") +} + +func (LegacyBroker) FinishPresence(sid string) error { + return nil +} diff --git a/broker/config.go b/broker/config.go index e058147d..d0780ea6 100644 --- a/broker/config.go +++ b/broker/config.go @@ -14,6 +14,8 @@ type Config struct { HistoryLimit int `toml:"history_limit"` // Sessions cache TTL in seconds (after disconnect) SessionsTTL int64 `toml:"sessions_ttl"` + // Presence expire TTL in seconds (after disconnect) + PresenceTTL int64 `toml:"presence_ttl"` } func NewConfig() Config { @@ -24,6 +26,8 @@ func NewConfig() Config { HistoryLimit: 100, // 5 minutes by default SessionsTTL: 5 * 60, + // 15 seconds by default + PresenceTTL: 15, } } @@ -46,6 +50,9 @@ func (c Config) ToToml() string { result.WriteString("# For how long to store sessions state for resumeability (seconds)\n") result.WriteString(fmt.Sprintf("sessions_ttl = %d\n", c.SessionsTTL)) + result.WriteString("# For how long to keep presence information after session disconnect (seconds)\n") + result.WriteString(fmt.Sprintf("presence_ttl = %d\n", c.PresenceTTL)) + result.WriteString("\n") return result.String() diff --git a/broker/memory.go b/broker/memory.go index 59f6b105..acf65826 100644 --- a/broker/memory.go +++ b/broker/memory.go @@ -4,10 +4,12 @@ import ( "context" "errors" "fmt" + "slices" "sync" "time" "github.com/anycable/anycable-go/common" + "github.com/anycable/anycable-go/utils" nanoid "github.com/matoous/go-nanoid" ) @@ -170,6 +172,59 @@ type expireSessionEntry struct { sid string } +type presenceSessionEntry struct { + // stream -> pid + streams map[string]string + deadline int64 +} + +type presenceEntry struct { + info interface{} + id string + sessions []string +} + +func (pe *presenceEntry) remove(sid string) bool { + i := -1 + + for idx, s := range pe.sessions { + if s == sid { + i = idx + break + } + } + + if i == -1 { + return false + } + + pe.sessions = append(pe.sessions[:i], pe.sessions[i+1:]...) + + return len(pe.sessions) == 0 +} + +func (pe *presenceEntry) add(sid string, info interface{}) { + if !slices.Contains(pe.sessions, sid) { + pe.sessions = append(pe.sessions, sid) + } + + pe.info = info +} + +type presenceState struct { + streams map[string]map[string]*presenceEntry + sessions map[string]*presenceSessionEntry + + mu sync.RWMutex +} + +func newPresenceState() *presenceState { + return &presenceState{ + streams: make(map[string]map[string]*presenceEntry), + sessions: make(map[string]*presenceSessionEntry), + } +} + type Memory struct { broadcaster Broadcaster config *Config @@ -179,6 +234,8 @@ type Memory struct { sessions map[string]*sessionEntry expireSessions []*expireSessionEntry + presence *presenceState + streamsMu sync.RWMutex sessionsMu sync.RWMutex epochMu sync.RWMutex @@ -195,6 +252,7 @@ func NewMemoryBroker(node Broadcaster, config *Config) *Memory { tracker: NewStreamsTracker(), streams: make(map[string]*memstream), sessions: make(map[string]*sessionEntry), + presence: newPresenceState(), epoch: epoch, } } @@ -376,18 +434,162 @@ func (b *Memory) RestoreSession(from string) ([]byte, error) { func (b *Memory) FinishSession(sid string) error { b.sessionsMu.Lock() - defer b.sessionsMu.Unlock() - if _, ok := b.sessions[sid]; ok { b.expireSessions = append( b.expireSessions, &expireSessionEntry{sid: sid, deadline: time.Now().Unix() + b.config.SessionsTTL}, ) } + b.sessionsMu.Unlock() return nil } +func (b *Memory) FinishPresence(sid string) error { + b.presence.mu.Lock() + + if sp, ok := b.presence.sessions[sid]; ok { + sp.deadline = time.Now().Unix() + b.config.PresenceTTL + } + + b.presence.mu.Unlock() + + return nil +} + +func (b *Memory) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { + b.presence.mu.Lock() + defer b.presence.mu.Unlock() + + if _, ok := b.presence.streams[stream]; !ok { + b.presence.streams[stream] = make(map[string]*presenceEntry) + } + + if _, ok := b.presence.sessions[sid]; !ok { + b.presence.sessions[sid] = &presenceSessionEntry{ + streams: make(map[string]string), + } + } + + if oldPid, ok := b.presence.sessions[sid].streams[stream]; ok && oldPid != pid { + return nil, errors.New("presence ID mismatch") + } + + b.presence.sessions[sid].streams[stream] = pid + + streamPresence := b.presence.streams[stream] + + newPresence := false + + if _, ok := streamPresence[pid]; !ok { + newPresence = true + streamPresence[pid] = &presenceEntry{ + info: info, + id: pid, + sessions: []string{}, + } + } + + streamSessionPresence := streamPresence[pid] + + streamSessionPresence.add(sid, info) + + if newPresence { + return &common.PresenceEvent{ + Type: common.PresenceJoinType, + ID: pid, + Info: info, + }, nil + } + + return nil, nil +} + +func (b *Memory) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { + b.presence.mu.Lock() + defer b.presence.mu.Unlock() + + if _, ok := b.presence.streams[stream]; !ok { + return nil, errors.New("stream not found") + } + + var pid string + + if ses, ok := b.presence.sessions[sid]; ok { + if id, ok := ses.streams[stream]; !ok { + return nil, errors.New("presence info not found") + } else { + pid = id + } + + delete(ses.streams, stream) + + if len(ses.streams) == 0 { + delete(b.presence.sessions, sid) + } + } + + streamPresence := b.presence.streams[stream] + + if _, ok := streamPresence[pid]; !ok { + return nil, errors.New("presence record not found") + } + + streamSessionPresence := streamPresence[pid] + + empty := streamSessionPresence.remove(sid) + + if empty { + delete(streamPresence, pid) + } + + if len(streamPresence) == 0 { + delete(b.presence.streams, stream) + } + + if empty { + return &common.PresenceEvent{ + Type: common.PresenceLeaveType, + ID: pid, + }, nil + } + + return nil, nil +} + +func (b *Memory) PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) { + options := NewPresenceInfoOptions() + for _, opt := range opts { + opt(options) + } + + b.presence.mu.RLock() + defer b.presence.mu.RUnlock() + + info := common.NewPresenceInfo() + + if _, ok := b.presence.streams[stream]; !ok { + return info, nil + } + + streamPresence := b.presence.streams[stream] + + info.Total = len(streamPresence) + + if options.ReturnRecords { + info.Records = make([]*common.PresenceEvent, 0, len(streamPresence)) + + for _, entry := range streamPresence { + info.Records = append(info.Records, &common.PresenceEvent{ + Info: entry.info, + ID: entry.id, + }) + } + } + + return info, nil +} + func (b *Memory) add(name string, data string) uint64 { b.streamsMu.Lock() @@ -457,4 +659,71 @@ func (b *Memory) expire() { b.expireSessions = b.expireSessions[i:] b.sessionsMu.Unlock() + + // presence expiration + b.expirePresence() +} + +func (b *Memory) expirePresence() { + b.presence.mu.Lock() + + now := time.Now().Unix() + toDelete := []string{} + + for sid, sp := range b.presence.sessions { + if sp.deadline > 0 && sp.deadline < now { + toDelete = append(toDelete, sid) + } + } + + leaveMessages := []common.StreamMessage{} + + for _, sid := range toDelete { + entry := b.presence.sessions[sid] + + for stream, pid := range entry.streams { + if _, ok := b.presence.streams[stream]; !ok { + continue + } + + if _, ok := b.presence.streams[stream][pid]; !ok { + continue + } + + streamSessionPresence := b.presence.streams[stream][pid] + + empty := streamSessionPresence.remove(sid) + + if empty { + delete(b.presence.streams[stream], pid) + + msg := &common.PresenceEvent{Type: common.PresenceLeaveType, ID: pid} + + leaveMessages = append(leaveMessages, common.StreamMessage{ + Stream: stream, + Data: string(utils.ToJSON(msg)), + Meta: &common.StreamMessageMetadata{ + BroadcastType: common.PresenceType, + Transient: true, + }, + }) + + if len(b.presence.streams[stream]) == 0 { + delete(b.presence.streams, stream) + } + } + } + + delete(b.presence.sessions, sid) + } + + b.presence.mu.Unlock() + + if b.broadcaster != nil { + // TODO: batch broadcast? + // FIXME: move broadcasts out of broker + for _, msg := range leaveMessages { + b.broadcaster.Broadcast(&msg) + } + } } diff --git a/broker/memory_test.go b/broker/memory_test.go index f0a32820..6d8bc6ee 100644 --- a/broker/memory_test.go +++ b/broker/memory_test.go @@ -1,9 +1,11 @@ package broker import ( + "slices" "testing" "time" + "github.com/anycable/anycable-go/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -149,3 +151,93 @@ func TestMemstream_filterByOffset(t *testing.T) { }) require.Error(t, err) } + +func TestMemory_Presence(t *testing.T) { + config := NewConfig() + + broker := NewMemoryBroker(nil, &config) + + ev, err := broker.PresenceAdd("a", "s1", "user_1", map[string]interface{}{"name": "John"}) + require.NoError(t, err) + + assert.Equal(t, "user_1", ev.ID) + assert.Equal(t, "join", ev.Type) + assert.Equal(t, map[string]interface{}{"name": "John"}, ev.Info) + + // Adding presence for the same session with different ID is illegal + ev, err = broker.PresenceAdd("a", "s1", "user_2", map[string]interface{}{"name": "Boo"}) + require.Error(t, err) + assert.Nil(t, ev) + + ev, err = broker.PresenceAdd("a", "s2", "user_1", map[string]interface{}{"name": "Jack"}) + require.NoError(t, err) + assert.Nil(t, ev) + + ev, err = broker.PresenceAdd("a", "s3", "user_2", map[string]interface{}{"name": "Alice"}) + require.NoError(t, err) + assert.Equal(t, "user_2", ev.ID) + + ev, err = broker.PresenceAdd("b", "s3", "user_2", map[string]interface{}{"name": "Alice"}) + require.NoError(t, err) + assert.Equal(t, "user_2", ev.ID) + + info, err := broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 2, info.Total) + assert.Equal(t, 2, len(info.Records)) + + // Make sure the latest info is returned + assert.Truef(t, slices.ContainsFunc(info.Records, func(r *common.PresenceEvent) bool { + return r.ID == "user_1" && (r.Info.(map[string]interface{})["name"] == "Jack") + }), "presence user with user_id and name:Jack not found: %s", info.Records) + + // Now let's check that leave works + ev, err = broker.PresenceRemove("a", "s1") + require.NoError(t, err) + assert.Nil(t, ev) + + info, err = broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 2, info.Total) + + ev, err = broker.PresenceRemove("a", "s2") + require.NoError(t, err) + assert.Equal(t, "user_1", ev.ID) + + info, err = broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 1, info.Total) +} + +func TestMemory_expirePresence(t *testing.T) { + config := NewConfig() + config.PresenceTTL = 1 + + broker := NewMemoryBroker(nil, &config) + + broker.PresenceAdd("a", "s1", "user_1", "john") // nolint: errcheck + broker.PresenceAdd("a", "s2", "user_2", "kate") // nolint: errcheck + + info, err := broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 2, info.Total) + + broker.FinishPresence("s1") // nolint: errcheck + broker.FinishPresence("s2") // nolint: errcheck + + time.Sleep(2 * time.Second) + + broker.PresenceAdd("a", "s3", "user_1", "john") // nolint: errcheck + + broker.expire() + + info, err = broker.PresenceInfo("a") + require.NoError(t, err) + + assert.Equal(t, 1, info.Total) + assert.Equal(t, "user_1", info.Records[0].ID) +} diff --git a/broker/nats.go b/broker/nats.go index 46153ae5..5e5c8019 100644 --- a/broker/nats.go +++ b/broker/nats.go @@ -482,6 +482,22 @@ func (n *NATS) Reset() error { return nil } +func (n *NATS) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") +} + +func (n *NATS) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { + return nil, errors.New("presence not supported") +} + +func (n *NATS) PresenceInfo(stream string, opts ...PresenceInfoOption) (*common.PresenceInfo, error) { + return nil, errors.New("presence not supported") +} + +func (n *NATS) FinishPresence(sid string) error { + return nil +} + func (n *NATS) add(stream string, data string) (uint64, error) { err := n.ensureStreamExists(stream) diff --git a/cli/options.go b/cli/options.go index 38b0ea71..e15c113b 100644 --- a/cli/options.go +++ b/cli/options.go @@ -637,6 +637,12 @@ func brokerCLIFlags(c *config.Config) []cli.Flag { Value: c.Broker.SessionsTTL, Destination: &c.Broker.SessionsTTL, }, + &cli.Int64Flag{ + Name: "presence_ttl", + Usage: "TTL for presence information (seconds)", + Value: c.Broker.PresenceTTL, + Destination: &c.Broker.PresenceTTL, + }, }) } diff --git a/common/common.go b/common/common.go index 671bbd98..abbdccea 100644 --- a/common/common.go +++ b/common/common.go @@ -61,10 +61,16 @@ const ( RejectedType = "reject_subscription" // Not supported by Action Cable currently UnsubscribedType = "unsubscribed" + ErrorType = "error" HistoryConfirmedType = "confirm_history" HistoryRejectedType = "reject_history" + PresenceInfoType = "info" + PresenceJoinType = "join" + PresenceLeaveType = "leave" + PresenceType = "presence" + WhisperType = "whisper" ) @@ -79,7 +85,8 @@ const ( // Reserver state fields const ( - WHISPER_STREAM_STATE = "$w" + WHISPER_STREAM_STATE = "$w" + PRESENCE_STREAM_STATE = "$p" ) // SessionEnv represents the underlying HTTP connection data: @@ -217,6 +224,25 @@ func (c *ConnectResult) ToCallResult() *CallResult { return &res } +type PresenceEvent struct { + Type string `json:"type,omitempty"` + Info interface{} `json:"info,omitempty"` + ID string `json:"id"` +} + +type PresenceInfo struct { + // Type is always "presence" + Type string `json:"type,omitempty"` + // Total number of present clients (uniq) + Total int `json:"total"` + // Presence records + Records []*PresenceEvent `json:"records,omitempty"` +} + +func NewPresenceInfo() *PresenceInfo { + return &PresenceInfo{Type: PresenceInfoType} +} + // CommandResult is a result of performing controller action, // which contains informations about streams to subscribe, // messages to sent and broadcast. @@ -301,6 +327,7 @@ type Message struct { Identifier string `json:"identifier"` Data interface{} `json:"data,omitempty"` History HistoryRequest `json:"history,omitempty"` + Presence *PresenceEvent `json:"presence,omitempty"` } func (m *Message) LogValue() slog.Value { @@ -464,17 +491,18 @@ func NewDisconnectMessage(reason string, reconnect bool) *DisconnectMessage { // Reply represents an outgoing client message type Reply struct { - Type string `json:"type,omitempty"` - Identifier string `json:"identifier,omitempty"` - Message interface{} `json:"message,omitempty"` - Reason string `json:"reason,omitempty"` - Reconnect bool `json:"reconnect,omitempty"` - StreamID string `json:"stream_id,omitempty"` - Epoch string `json:"epoch,omitempty"` - Offset uint64 `json:"offset,omitempty"` - Sid string `json:"sid,omitempty"` - Restored bool `json:"restored,omitempty"` - RestoredIDs []string `json:"restored_ids,omitempty"` + Type string `json:"type,omitempty"` + Identifier string `json:"identifier,omitempty"` + Message interface{} `json:"message,omitempty"` + Presence *PresenceEvent `json:"presence,omitempty"` + Reason string `json:"reason,omitempty"` + Reconnect bool `json:"reconnect,omitempty"` + StreamID string `json:"stream_id,omitempty"` + Epoch string `json:"epoch,omitempty"` + Offset uint64 `json:"offset,omitempty"` + Sid string `json:"sid,omitempty"` + Restored bool `json:"restored,omitempty"` + RestoredIDs []string `json:"restored_ids,omitempty"` } func (r *Reply) LogValue() slog.Value { diff --git a/docs/.trieve.yml b/docs/.trieve.yml index 12565e0d..a5815d9e 100644 --- a/docs/.trieve.yml +++ b/docs/.trieve.yml @@ -32,6 +32,7 @@ pages: - "./broker.md" - "./configuration.md" - "./embedded_nats.md" + - "./presence.md" - "./getting_started.md" - "./health_checking.md" - "./instrumentation.md" diff --git a/docs/Readme.md b/docs/Readme.md index dcf6e8b5..6909c2dc 100644 --- a/docs/Readme.md +++ b/docs/Readme.md @@ -12,5 +12,6 @@ * [Binary formats](binary_formats.md) * [JWT identification](jwt_identification.md) * [Signed streams](signed_streams.md) +* [Presence](presence.md) * [Embedded NATS](embedded_nats.md) * [Using as a library](library.md) diff --git a/docs/broker.md b/docs/broker.md index 3cbe2a88..e1ffe820 100644 --- a/docs/broker.md +++ b/docs/broker.md @@ -1,11 +1,12 @@ # Broker deep dive -Broker is a component of AnyCable-Go responsible for keeping streams and sessions information in a cache-like storage. It drives the [Reliable Streams](./reliable_streams.md) feature. +Broker is a component of AnyCable-Go responsible for keeping streams, sessions and presence information in a cache-like storage. It drives the [Reliable Streams](./reliable_streams.md) and [Presence](./presence.md) features. Broker implements features that can be characterized as _hot cache utilities_: - Handling incoming broadcast messages and storing them in a cache—that could help clients to receive missing broadcasts (triggered while the client was offline, for example). - Persisting client states—to make it possible to restore on re-connection (by providing a _session id_ of the previous connection). +- Keeping per-channel presence information. ## Client-server communication diff --git a/docs/presence.md b/docs/presence.md new file mode 100644 index 00000000..06c99305 --- /dev/null +++ b/docs/presence.md @@ -0,0 +1,74 @@ +# Presence tracking + +AnyCable comes with a built-in presence tracking support for your real-time applications. No need to write custom code and deal with storage mechanisms to know who's online in your channels. + +## Overview + +AnyCable presence allows channel subscribers to share their presence information with other clients and track the changes in the channel's presence set. Presence data can be used to display a list of online users, track user activity, etc. + +## Quick start + +Presence is a part of the [broker](./broker.md) component, so you must enable it either via the `broker` configuration preset or manually: + +```sh +$ anycable-go --presets=broker + +# or + +$ anycable-go --broker=memory +``` + +> 🚧 Currently, presence tracking is only supported by the memory broker. + +Now, you can use the presence API in your application. For example, using [AnyCable JS client](https://github.com/anycable/anycable-client): + +```js +import { createCable } from '@anycable/web' +// or for non-web projects +// import { createCable } from '@anycable/core' + +const cable = createCable({protocol: 'actioncable-v1-ext-json'}) + +const channel = cable.streamFrom('room/42'); + +// join the channel's presence set +channel.presence.join(user.id, { name: user.name }) + +// get the current presence state +const presence = await chatChannel.presence.info() + +// subscribe to presence events +channel.on("presence", (event) => { + const { type, info, id } = event + + if (type === "join") { + console.log(`${info.name} joined the channel`) + } else if (type === "leave") { + console.log(`${id} left the channel`) + } +}) +``` + +## Presence lifecycle + +Clients join the presence set explicitly by performing the `presence` command. The `join` event is sent to all subscribers (including the initiator) with the presence information, but only if the **presence ID** (provided by the client) hasn't been registered yet. Thus, multiple sessions with the same ID are treated as a single presence record. + +Clients may explicitly leave the presence set by performing the `leave` command or by unsubscribing from the channel. The `leave` event is sent to all subscribers only if no other sessions with the same ID are left in the presence set. + +When a client disconnects without explicitly leaving or unsubscribing the channel, it's present information stays in the set for a short period of time. That prevents the burst of `join` / `leave` events when the client reconnects frequently. + +## Configuration + +You can configure the presence expiration time (for disconnected clients) via the `--presence_ttl` option. The default value is 15 seconds. + +## Presence for channels + +> 🚧 Presence for channels is to be implemented. You can only use presence with [signed streams](./signed_streams.md) for now. + +## Presence API + +> 🚧 Presence REST API is to be implemented. You can only use the presence API via the WebSocket connection. + +## Presence webhooks + +> 🚧 Presence webhooks are to be implemented, too. diff --git a/forspell.dict b/forspell.dict index 6b882f6b..366e4af8 100644 --- a/forspell.dict +++ b/forspell.dict @@ -67,3 +67,4 @@ reconnection norpc noauth jid +webhooks diff --git a/mocks/Broker.go b/mocks/Broker.go index 6ce183e4..422a51bf 100644 --- a/mocks/Broker.go +++ b/mocks/Broker.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v2.50.0. DO NOT EDIT. package mocks @@ -16,10 +16,14 @@ type Broker struct { mock.Mock } -// Announce provides a mock function with given fields: +// Announce provides a mock function with no fields func (_m *Broker) Announce() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Announce") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -34,6 +38,10 @@ func (_m *Broker) Announce() string { func (_m *Broker) CommitSession(sid string, session broker.Cacheable) error { ret := _m.Called(sid, session) + if len(ret) == 0 { + panic("no return value specified for CommitSession") + } + var r0 error if rf, ok := ret.Get(0).(func(string, broker.Cacheable) error); ok { r0 = rf(sid, session) @@ -44,10 +52,32 @@ func (_m *Broker) CommitSession(sid string, session broker.Cacheable) error { return r0 } +// FinishPresence provides a mock function with given fields: sid +func (_m *Broker) FinishPresence(sid string) error { + ret := _m.Called(sid) + + if len(ret) == 0 { + panic("no return value specified for FinishPresence") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(sid) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // FinishSession provides a mock function with given fields: sid func (_m *Broker) FinishSession(sid string) error { ret := _m.Called(sid) + if len(ret) == 0 { + panic("no return value specified for FinishSession") + } + var r0 error if rf, ok := ret.Get(0).(func(string) error); ok { r0 = rf(sid) @@ -72,6 +102,10 @@ func (_m *Broker) HandleCommand(msg *common.RemoteCommandMessage) { func (_m *Broker) HistoryFrom(stream string, epoch string, offset uint64) ([]common.StreamMessage, error) { ret := _m.Called(stream, epoch, offset) + if len(ret) == 0 { + panic("no return value specified for HistoryFrom") + } + var r0 []common.StreamMessage var r1 error if rf, ok := ret.Get(0).(func(string, string, uint64) ([]common.StreamMessage, error)); ok { @@ -98,6 +132,10 @@ func (_m *Broker) HistoryFrom(stream string, epoch string, offset uint64) ([]com func (_m *Broker) HistorySince(stream string, ts int64) ([]common.StreamMessage, error) { ret := _m.Called(stream, ts) + if len(ret) == 0 { + panic("no return value specified for HistorySince") + } + var r0 []common.StreamMessage var r1 error if rf, ok := ret.Get(0).(func(string, int64) ([]common.StreamMessage, error)); ok { @@ -120,10 +158,111 @@ func (_m *Broker) HistorySince(stream string, ts int64) ([]common.StreamMessage, return r0, r1 } +// PresenceAdd provides a mock function with given fields: stream, sid, pid, info +func (_m *Broker) PresenceAdd(stream string, sid string, pid string, info interface{}) (*common.PresenceEvent, error) { + ret := _m.Called(stream, sid, pid, info) + + if len(ret) == 0 { + panic("no return value specified for PresenceAdd") + } + + var r0 *common.PresenceEvent + var r1 error + if rf, ok := ret.Get(0).(func(string, string, string, interface{}) (*common.PresenceEvent, error)); ok { + return rf(stream, sid, pid, info) + } + if rf, ok := ret.Get(0).(func(string, string, string, interface{}) *common.PresenceEvent); ok { + r0 = rf(stream, sid, pid, info) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*common.PresenceEvent) + } + } + + if rf, ok := ret.Get(1).(func(string, string, string, interface{}) error); ok { + r1 = rf(stream, sid, pid, info) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PresenceInfo provides a mock function with given fields: stream, opts +func (_m *Broker) PresenceInfo(stream string, opts ...broker.PresenceInfoOption) (*common.PresenceInfo, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, stream) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for PresenceInfo") + } + + var r0 *common.PresenceInfo + var r1 error + if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) (*common.PresenceInfo, error)); ok { + return rf(stream, opts...) + } + if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) *common.PresenceInfo); ok { + r0 = rf(stream, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*common.PresenceInfo) + } + } + + if rf, ok := ret.Get(1).(func(string, ...broker.PresenceInfoOption) error); ok { + r1 = rf(stream, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PresenceRemove provides a mock function with given fields: stream, sid +func (_m *Broker) PresenceRemove(stream string, sid string) (*common.PresenceEvent, error) { + ret := _m.Called(stream, sid) + + if len(ret) == 0 { + panic("no return value specified for PresenceRemove") + } + + var r0 *common.PresenceEvent + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (*common.PresenceEvent, error)); ok { + return rf(stream, sid) + } + if rf, ok := ret.Get(0).(func(string, string) *common.PresenceEvent); ok { + r0 = rf(stream, sid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*common.PresenceEvent) + } + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(stream, sid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RestoreSession provides a mock function with given fields: from func (_m *Broker) RestoreSession(from string) ([]byte, error) { ret := _m.Called(from) + if len(ret) == 0 { + panic("no return value specified for RestoreSession") + } + var r0 []byte var r1 error if rf, ok := ret.Get(0).(func(string) ([]byte, error)); ok { @@ -150,6 +289,10 @@ func (_m *Broker) RestoreSession(from string) ([]byte, error) { func (_m *Broker) Shutdown(ctx context.Context) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for Shutdown") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(ctx) @@ -164,6 +307,10 @@ func (_m *Broker) Shutdown(ctx context.Context) error { func (_m *Broker) Start(done chan error) error { ret := _m.Called(done) + if len(ret) == 0 { + panic("no return value specified for Start") + } + var r0 error if rf, ok := ret.Get(0).(func(chan error) error); ok { r0 = rf(done) @@ -178,6 +325,10 @@ func (_m *Broker) Start(done chan error) error { func (_m *Broker) Subscribe(stream string) string { ret := _m.Called(stream) + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + var r0 string if rf, ok := ret.Get(0).(func(string) string); ok { r0 = rf(stream) @@ -192,6 +343,10 @@ func (_m *Broker) Subscribe(stream string) string { func (_m *Broker) Unsubscribe(stream string) string { ret := _m.Called(stream) + if len(ret) == 0 { + panic("no return value specified for Unsubscribe") + } + var r0 string if rf, ok := ret.Get(0).(func(string) string); ok { r0 = rf(stream) @@ -202,13 +357,12 @@ func (_m *Broker) Unsubscribe(stream string) string { return r0 } -type mockConstructorTestingTNewBroker interface { +// NewBroker creates a new instance of Broker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewBroker(t interface { mock.TestingT Cleanup(func()) -} - -// NewBroker creates a new instance of Broker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewBroker(t mockConstructorTestingTNewBroker) *Broker { +}) *Broker { mock := &Broker{} mock.Mock.Test(t) diff --git a/node/broker_integration_test.go b/node/broker_integration_test.go index 95f70b1e..2a4c1675 100644 --- a/node/broker_integration_test.go +++ b/node/broker_integration_test.go @@ -393,6 +393,173 @@ func sharedIntegrationHistory(t *testing.T, node *Node, controller *mocks.Contro }) } +// A test to verify the presence flow. +// +// SETUP: +// - Two sessions are created (authenticated and subscribed with presence stream). +// +// TEST 1 — join/leave events and info: +// - First session joins the channel. +// - Both sessions receive join event. +// - Second session requests presence info. +// - Second session joins the channel. +// - First session leaves the channel. +// - Both sessions receive leave event. +// - Second session requests presence info. +// +// TEST 2 — presence expiration: +// - Both sessions joins the channel. +// - Both sessions receive join event. +// - First session disconnects. +// - Wait for expiration. +// - Second session receives leave event. +// - Second session requests presence info. +func TestIntegrationPresence_Memory(t *testing.T) { + node, controller := setupIntegrationNode() + + bconf := broker.NewConfig() + bconf.PresenceTTL = 2 + + subscriber := pubsub.NewLegacySubscriber(node) + + br := broker.NewMemoryBroker(subscriber, &bconf) + node.SetBroker(br) + + require.NoError(t, br.Start(nil)) + + go node.Start() // nolint:errcheck + defer node.Shutdown(context.Background()) // nolint:errcheck + + sharedIntegrationPresence(t, node, controller) +} + +func sharedIntegrationPresence(t *testing.T, node *Node, controller *mocks.Controller) { + controller. + On("Subscribe", "sasha", mock.Anything, "sasha", "chat_1"). + Return(&common.CommandResult{ + Status: common.SUCCESS, + Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`}, + Streams: []string{"presence_1", "messages_1"}, + IState: map[string]string{common.PRESENCE_STREAM_STATE: "presence_1"}, + }, nil) + controller. + On("Subscribe", "mia", mock.Anything, "mia", "chat_1"). + Return(&common.CommandResult{ + Status: common.SUCCESS, + Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`}, + Streams: []string{"presence_1", "messages_1"}, + IState: map[string]string{common.PRESENCE_STREAM_STATE: "presence_1"}, + }, nil) + controller. + On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(&common.CommandResult{ + Status: common.SUCCESS, + }, nil) + + setupSessions := func() (*Session, *Session, func()) { + session := requireAuthenticatedSession(t, node, "sasha") + _, err := node.Subscribe( + session, + &common.Message{ + Identifier: "chat_1", + Command: "subscribe", + }) + require.NoError(t, err) + assertReceive(t, session, `{"type":"confirm","identifier":"chat_1"}`) + + session2 := requireAuthenticatedSession(t, node, "mia") + _, err = node.Subscribe( + session2, + &common.Message{ + Identifier: "chat_1", + Command: "subscribe", + }) + require.NoError(t, err) + assertReceive(t, session2, `{"type":"confirm","identifier":"chat_1"}`) + + return session, session2, func() { + node.Unsubscribe(session, &common.Message{Identifier: "chat_1", Command: "unsubscribe"}) // nolint:errcheck + node.Unsubscribe(session2, &common.Message{Identifier: "chat_1", Command: "unsubscribe"}) // nolint:errcheck + } + } + + t.Run("Join and leave", func(t *testing.T) { + sasha, mia, cleanup := setupSessions() + defer cleanup() + + err := node.PresenceJoin(sasha, &common.Message{Identifier: "chat_1", Presence: &common.PresenceEvent{ID: "42", Info: map[string]interface{}{"name": "Sasha"}}}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"42","info":{"name":"Sasha"},"type":"join"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"42","info":{"name":"Sasha"},"type":"join"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":1,"records":[{"id":"42","info":{"name":"Sasha"}}]}}`) + + err = node.PresenceJoin(mia, &common.Message{Identifier: "chat_1", Presence: &common.PresenceEvent{ID: "13", Info: map[string]interface{}{"name": "Mia"}}}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"13","info":{"name":"Mia"},"type":"join"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"13","info":{"name":"Mia"},"type":"join"}}`) + + err = node.Presence(sasha, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + msg := assertReceiveMsg(t, sasha) + assert.Equal(t, "presence", msg["type"]) + assert.Equal(t, "chat_1", msg["identifier"]) + payload := msg["message"].(map[string]interface{}) + assert.Equal(t, "info", payload["type"]) + assert.Equal(t, float64(2), payload["total"]) + records := payload["records"].([]interface{}) + assert.Len(t, records, 2) + assert.Contains(t, records, map[string]interface{}{"id": "42", "info": map[string]interface{}{"name": "Sasha"}}) + assert.Contains(t, records, map[string]interface{}{"id": "13", "info": map[string]interface{}{"name": "Mia"}}) + + err = node.PresenceLeave(sasha, &common.Message{Identifier: "chat_1"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"42","type":"leave"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"42","type":"leave"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":1,"records":[{"id":"13","info":{"name":"Mia"}}]}}`) + }) + + t.Run("Presence expiration", func(t *testing.T) { + sasha, mia, cleanup := setupSessions() + defer cleanup() + + err := node.PresenceJoin(sasha, &common.Message{Identifier: "chat_1", Presence: &common.PresenceEvent{ID: "142", Info: map[string]interface{}{"name": "Rickie"}}}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"142","info":{"name":"Rickie"},"type":"join"}}`) + assertReceive(t, sasha, `{"type":"presence","identifier":"chat_1","message":{"id":"142","info":{"name":"Rickie"},"type":"join"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":1,"records":[{"id":"142","info":{"name":"Rickie"}}]}}`) + + err = node.Disconnect(sasha) + require.NoError(t, err) + + // Wait for expiration to happen + time.Sleep(4 * time.Second) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"id":"142","type":"leave"}}`) + + err = node.Presence(mia, &common.Message{Identifier: "chat_1", Command: "presence"}) + require.NoError(t, err) + + assertReceive(t, mia, `{"type":"presence","identifier":"chat_1","message":{"type":"info","total":0}}`) + }) +} + func setupIntegrationNode() (*Node, *mocks.Controller) { config := NewConfig() config.HubGopoolSize = 2 @@ -419,16 +586,32 @@ func requireReceive(t *testing.T, s *Session, expected string) { } func assertReceive(t *testing.T, s *Session, expected string) { - msg, err := s.conn.Read() + parsedMessage := assertReceiveMsg(t, s) + + var expectedMessage map[string]interface{} + + err := json.Unmarshal([]byte(expected), &expectedMessage) require.NoError(t, err) assert.Equal( t, - expected, - string(msg), + expectedMessage, + parsedMessage, ) } +func assertReceiveMsg(t *testing.T, s *Session) map[string]interface{} { + msg, err := s.conn.Read() + require.NoError(t, err) + + var parsedMessage map[string]interface{} + + err = json.Unmarshal(msg, &parsedMessage) + require.NoError(t, err) + + return parsedMessage +} + func requireAuthenticatedSession(t *testing.T, node *Node, sid string) *Session { session := NewMockSessionWithEnv(sid, node, "ws://test.anycable.io/cable", nil) diff --git a/node/node.go b/node/node.go index a1f0a7b3..332fe19b 100644 --- a/node/node.go +++ b/node/node.go @@ -2,6 +2,7 @@ package node import ( "context" + "encoding/json" "errors" "fmt" "log/slog" @@ -175,6 +176,12 @@ func (n *Node) HandleCommand(s *Session, msg *common.Message) (err error) { _, err = n.Perform(s, msg) case "history": err = n.History(s, msg) + case "presence": + err = n.Presence(s, msg) + case "join": + err = n.PresenceJoin(s, msg) + case "leave": + err = n.PresenceLeave(s, msg) case "whisper": err = n.Whisper(s, msg) default: @@ -471,8 +478,12 @@ func (n *Node) Subscribe(s *Session, msg *common.Message) (*common.CommandResult if err := n.History(s, msg); err != nil { s.Log.Warn("couldn't retrieve history", "identifier", msg.Identifier, "error", err) } + } - return res, nil + if msg.Presence != nil { + if err := n.handlePresenceReply(s, msg.Identifier, common.PresenceJoinType, msg.Presence); err != nil { + s.Log.Warn("couldn't process presence join", "identifier", msg.Identifier, "error", err) + } } } @@ -492,6 +503,8 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResu s.Log.Debug("controller unsubscribe", "response", res, "err", err) + presenceStream := "" + if err != nil { if res == nil || res.Status == common.ERROR { return nil, errorx.Decorate(err, "failed to unsubscribe from %s", msg.Identifier) @@ -500,6 +513,8 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResu // Make sure to remove all streams subscriptions res.StopAllStreams = true + presenceStream = s.env.GetChannelStateField(msg.Identifier, common.PRESENCE_STREAM_STATE) + s.env.RemoveChannelState(msg.Identifier) s.subscriptions.RemoveChannel(msg.Identifier) @@ -518,6 +533,13 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResu } } + // Make sure presence is removed on explicit unsubscribe + if presenceStream != "" { + if _, err := n.broker.PresenceRemove(presenceStream, s.GetID()); err != nil { + s.Log.Error("failed to remove presence", "error", err) + } + } + return res, nil } @@ -670,6 +692,89 @@ func (n *Node) Whisper(s *Session, msg *common.Message) error { return nil } +// Presence returns the presence information for the specified identifier +func (n *Node) Presence(s *Session, msg *common.Message) error { + s.smu.Lock() + + if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { + s.smu.Unlock() + return fmt.Errorf("unknown subscription %s", msg.Identifier) + } + + // Check that the presence stream is configured (thus, the feature is enabled) + env := s.GetEnv() + if env == nil { + s.smu.Unlock() + return errors.New("session environment is missing") + } + + stream := env.GetChannelStateField(msg.Identifier, common.PRESENCE_STREAM_STATE) + + if stream == "" { + s.smu.Unlock() + return fmt.Errorf("presence stream not found for identifier: %s", msg.Identifier) + } + + s.smu.Unlock() + + options := broker.NewPresenceInfoOptions() + + if msg.Data != nil { + buf, err := json.Marshal(&msg.Data) + if err != nil { + json.Unmarshal(buf, &options) // nolint:errcheck + } + } + + info, err := n.broker.PresenceInfo(stream, broker.WithPresenceInfoOptions(options)) + + if err != nil { + s.Send(&common.Reply{ + Type: common.PresenceType, + Identifier: msg.Identifier, + Message: &common.PresenceInfo{Type: common.ErrorType}, + }) + + return err + } + + s.Send(&common.Reply{ + Type: common.PresenceType, + Identifier: msg.Identifier, + Message: info, + }) + + return nil +} + +// PresenceJoin adds the session to the presence state for the specified identifier +func (n *Node) PresenceJoin(s *Session, msg *common.Message) error { + s.smu.Lock() + + if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { + s.smu.Unlock() + return fmt.Errorf("unknown subscription %s", msg.Identifier) + } + + s.smu.Unlock() + + return n.handlePresenceReply(s, msg.Identifier, common.PresenceJoinType, msg.Presence) +} + +// PresenceLeave removes the session to the presence state for the specified identifier +func (n *Node) PresenceLeave(s *Session, msg *common.Message) error { + s.smu.Lock() + + if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { + s.smu.Unlock() + return fmt.Errorf("unknown subscription %s", msg.Identifier) + } + + s.smu.Unlock() + + return n.handlePresenceReply(s, msg.Identifier, common.PresenceLeaveType, msg.Presence) +} + // Broadcast message to stream (locally) func (n *Node) Broadcast(msg *common.StreamMessage) { n.metrics.CounterIncrement(metricsBroadcastMsg) @@ -701,6 +806,8 @@ func (n *Node) Disconnect(s *Session) error { n.broker.FinishSession(s.GetID()) // nolint:errcheck } + n.broker.FinishPresence(s.GetID()) // nolint:errcheck + if n.IsShuttingDown() { // Make sure session is removed from hub, so we don't try to send // broadcast messages to them @@ -820,6 +927,10 @@ func (n *Node) handleCommandReply(s *Session, msg *common.Message, reply *common } isConnectionDirty := n.handleCallReply(s, reply.ToCallResult()) + + // TODO: RPC-driven presence + // n.handlePresenceReply(s, reply.Presence) + return isDirty || isConnectionDirty } @@ -847,6 +958,49 @@ func (n *Node) handleCallReply(s *Session, reply *common.CallResult) bool { return isDirty } +func (n *Node) handlePresenceReply(s *Session, identifier string, event string, presence *common.PresenceEvent) error { + // Check that the presence stream is configured (thus, the feature is enabled) + env := s.GetEnv() + if env == nil { + return errors.New("session environment is missing") + } + + stream := env.GetChannelStateField(identifier, common.PRESENCE_STREAM_STATE) + + if stream == "" { + return fmt.Errorf("presence stream not found for identifier: %s", identifier) + } + + sid := s.GetID() + + var err error + var msg *common.PresenceEvent + + if event == common.PresenceJoinType { // nolint:gocritic + if presence == nil { + return errors.New("presence data is missing") + } + msg, err = n.broker.PresenceAdd(stream, sid, presence.ID, presence.Info) + } else if event == common.PresenceLeaveType { + msg, err = n.broker.PresenceRemove(stream, sid) + } else { + return fmt.Errorf("unknown presence event: %s", event) + } + + if msg != nil { + n.Broadcast(&common.StreamMessage{ + Stream: stream, + Data: string(utils.ToJSON(msg)), + Meta: &common.StreamMessageMetadata{ + BroadcastType: common.PresenceType, + Transient: true, + }, + }) + } + + return err +} + // disconnectScheduler controls how quickly to disconnect sessions type disconnectScheduler interface { // This method is called when a session is ready to be disconnected, diff --git a/streams/config.go b/streams/config.go index aff615bf..db695ef1 100644 --- a/streams/config.go +++ b/streams/config.go @@ -17,6 +17,9 @@ type Config struct { // Whisper determines if whispering is enabled for pub/sub streams Whisper bool `toml:"whisper"` + // Presence determines if presence is enabled for pub/sub streams + Presence bool `toml:"presence"` + // PubSubChannel is the channel name used for direct pub/sub PubSubChannel string `toml:"pubsub_channel"` @@ -80,6 +83,13 @@ func (c Config) ToToml() string { result.WriteString("# whisper = true\n") } + result.WriteString("# Enable presence support for pub/sub streams\n") + if c.Presence { + result.WriteString("presence = true\n") + } else { + result.WriteString("# presence = true\n") + } + result.WriteString("# Name of the channel used for pub/sub\n") result.WriteString(fmt.Sprintf("pubsub_channel = \"%s\"\n", c.PubSubChannel)) diff --git a/streams/controller.go b/streams/controller.go index 72f09139..eb16182d 100644 --- a/streams/controller.go +++ b/streams/controller.go @@ -15,7 +15,8 @@ type SubscribeRequest struct { StreamName string `json:"stream_name"` SignedStreamName string `json:"signed_stream_name"` - whisper bool + whisper bool + presence bool } func (r *SubscribeRequest) IsPresent() bool { @@ -111,7 +112,15 @@ func (c *Controller) Subscribe(sid string, env *common.SessionEnv, ids string, i var state map[string]string if request.whisper { - state = map[string]string{common.WHISPER_STREAM_STATE: stream} + state = make(map[string]string) + state[common.WHISPER_STREAM_STATE] = stream + } + + if request.presence { + if state == nil { + state = make(map[string]string) + } + state[common.PRESENCE_STREAM_STATE] = stream } return &common.CommandResult{ @@ -144,6 +153,7 @@ func NewStreamsController(conf *Config, l *slog.Logger) *Controller { key := conf.Secret allowPublic := conf.Public whispers := conf.Whisper + presence := conf.Presence resolver := func(identifier string) (*SubscribeRequest, error) { var request SubscribeRequest @@ -160,6 +170,10 @@ func NewStreamsController(conf *Config, l *slog.Logger) *Controller { request.whisper = true } + if presence || (request.StreamName != "") { + request.presence = true + } + return &request, nil }