diff --git a/relay/schema.go b/relay/schema.go index 354d322..664d18b 100644 --- a/relay/schema.go +++ b/relay/schema.go @@ -3,6 +3,7 @@ package relay import ( "encoding/json" "strings" + "sync" "go.uber.org/zap/zapcore" ) @@ -81,30 +82,43 @@ func cachedMessageKey(topic string) string { } // TopicClientSet stores topic -> clients relationship -type TopicClientSet map[string]map[*client]struct{} +type TopicClientSet struct { + *sync.RWMutex + Data map[string]map[*client]struct{} +} -func NewTopicClientSet() TopicClientSet { - return make(map[string]map[*client]struct{}) +func NewTopicClientSet() *TopicClientSet { + return &TopicClientSet{ + RWMutex: &sync.RWMutex{}, + Data: map[string]map[*client]struct{}{}, + } } -func (ts TopicClientSet) Get(topic string) map[*client]struct{} { - return ts[topic] +func (ts *TopicClientSet) Get(topic string) map[*client]struct{} { + ts.RLock() + defer ts.RUnlock() + return ts.Data[topic] } -func (ts TopicClientSet) Set(topic string, c *client) { - if _, ok := ts[topic]; !ok { - ts[topic] = make(map[*client]struct{}) +func (ts *TopicClientSet) Set(topic string, c *client) { + ts.Lock() + defer ts.Unlock() + if _, ok := ts.Data[topic]; !ok { + ts.Data[topic] = make(map[*client]struct{}) } - ts[topic][c] = struct{}{} + ts.Data[topic][c] = struct{}{} } // GetTopicsByClient returns the topics associated with the specified client, // meanwhile, remove the client from these topics if `clear` is true // returns the topics the client has associated with -func (ts TopicClientSet) GetTopicsByClient(c *client, clear bool) []string { +func (ts *TopicClientSet) GetTopicsByClient(c *client, clear bool) []string { + // Write lock in a read func because we may remove the client from the topics + ts.Lock() + defer ts.Unlock() topics := []string{} - for topic, set := range ts { + for topic, set := range ts.Data { if _, ok := set[c]; ok { topics = append(topics, topic) } @@ -115,26 +129,46 @@ func (ts TopicClientSet) GetTopicsByClient(c *client, clear bool) []string { return topics } -func (ts TopicClientSet) Unset(topic string, c *client) { - delete(ts[topic], c) +func (ts *TopicClientSet) Unset(topic string, c *client) { + ts.Lock() + defer ts.Unlock() + delete(ts.Data[topic], c) +} + +func (ts *TopicClientSet) Len(topic string) int { + ts.RLock() + defer ts.RUnlock() + return len(ts.Data[topic]) } -func (ts TopicClientSet) Len(topic string) int { - return len(ts[topic]) +func (ts *TopicClientSet) Clear(topic string) { + ts.Lock() + defer ts.Unlock() + delete(ts.Data, topic) } -func (ts TopicClientSet) Clear(topic string) { - delete(ts, topic) +type TopicSet struct { + *sync.RWMutex + Data map[string]struct{} } -type TopicSet map[string]struct{} +func NewTopicSet() *TopicSet { + return &TopicSet{ + RWMutex: &sync.RWMutex{}, + Data: map[string]struct{}{}, + } +} -func NewTopicSet() TopicSet { - return make(map[string]struct{}) +func (tm TopicSet) Set(topic string) { + tm.Lock() + defer tm.Unlock() + tm.Data[topic] = struct{}{} } func (tm TopicSet) MarshalLogArray(encoder zapcore.ArrayEncoder) error { - for topic := range tm { + tm.Lock() + defer tm.Unlock() + for topic := range tm.Data { encoder.AppendString(topic) } return nil diff --git a/relay/wsconn.go b/relay/wsconn.go index 57a78ae..a056e02 100644 --- a/relay/wsconn.go +++ b/relay/wsconn.go @@ -22,8 +22,8 @@ type client struct { active bool // heartbeat related role RoleType // dapp or wallet session string // session id - pubTopics TopicSet - subTopics TopicSet + pubTopics *TopicSet + subTopics *TopicSet sendbuf chan SocketMessage // send buffer ping chan struct{} diff --git a/relay/wsserver.go b/relay/wsserver.go index ace6226..1df8d54 100644 --- a/relay/wsserver.go +++ b/relay/wsserver.go @@ -29,8 +29,8 @@ type WsServer struct { redisConn *redis.Client redisSubConn *redis.PubSub - publishers TopicClientSet - subscribers TopicClientSet + publishers *TopicClientSet + subscribers *TopicClientSet localCh chan SocketMessage // for handling message of local clients } @@ -109,12 +109,12 @@ func (ws *WsServer) Run() { switch message.Type { case Pub: // do not modify wsserver's local variable in seperate goroutine - message.client.pubTopics[message.Topic] = struct{}{} + message.client.pubTopics.Set(message.Topic) ws.publishers.Set(message.Topic, message.client) go ws.pubMessage(message) log.Info("local message", zap.Any("client", message.client), zap.Any("message", message)) case Sub: - message.client.subTopics[message.Topic] = struct{}{} + message.client.subTopics.Set(message.Topic) ws.subscribers.Set(message.Topic, message.client) go ws.subMessage(message) log.Info("local message", zap.Any("client", message.client), zap.Any("message", message))