Skip to content

Commit

Permalink
fix: 消息接收器并发导致的异常
Browse files Browse the repository at this point in the history
  • Loading branch information
rehiy committed Mar 9, 2024
1 parent a544a21 commit ddce2fb
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 35 deletions.
2 changes: 1 addition & 1 deletion httpd/wcfrest/receiver_url.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

var urlReceiverList = map[string]string{}

func urlReciever(url string) wcferry.MsgCallback {
func urlReciever(url string) wcferry.MsgConsumer {

return func(msg *wcferry.WxMsg) {
ret := wcferry.ParseWxMsg(msg)
Expand Down
2 changes: 1 addition & 1 deletion httpd/wcfrest/receiver_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/opentdp/wechat-rest/wcferry"
)

func socketReceiver(ws *websocket.Conn) wcferry.MsgCallback {
func socketReceiver(ws *websocket.Conn) wcferry.MsgConsumer {

mu := sync.Mutex{}

Expand Down
8 changes: 4 additions & 4 deletions wcferry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ func (c *Client) Connect() error {

// 启动消息接收器
// param pyq bool 是否接收朋友圈消息
// param cb MsgCallback 消息回调函数,可选参数
// param cb MsgConsumer 消息回调函数,可选参数
// return string 接收器唯一标识
func (c *Client) EnrollReceiver(pyq bool, cb MsgCallback) (string, error) {
if c.MsgClient.callbacks == nil {
func (c *Client) EnrollReceiver(pyq bool, cb MsgConsumer) (string, error) {
if c.MsgClient.consumer == nil {
if c.CmdClient.EnableMsgReciver(true) != 0 {
return "", errors.New("failed to enable msg server")
}
Expand All @@ -68,7 +68,7 @@ func (c *Client) EnrollReceiver(pyq bool, cb MsgCallback) (string, error) {
// return error 错误信息
func (c *Client) DisableReceiver(ks ...string) error {
err := c.MsgClient.Destroy(ks...)
if c.MsgClient.callbacks == nil {
if c.MsgClient.consumer == nil {
if c.CmdClient.DisableMsgReciver() != 0 {
return errors.New("failed to disable msg server")
}
Expand Down
68 changes: 39 additions & 29 deletions wcferry/msg_client.go
Original file line number Diff line number Diff line change
@@ -1,66 +1,76 @@
package wcferry

import (
"sync"

"github.com/opentdp/go-helper/logman"
"github.com/opentdp/go-helper/recovery"
"github.com/opentdp/go-helper/strutil"
)

type MsgClient struct {
*pbSocket // RPC 客户端
callbacks map[string]MsgCallback // 推送函数列表
mu sync.Mutex // 互斥锁
consumer map[string]MsgConsumer // 消费者
}

// 消息回调函数
type MsgCallback func(msg *WxMsg)
type MsgConsumer func(msg *WxMsg)

// 关闭 RPC 连接
// param ks 消息接收器标识,空则关闭所有
// return error 错误信息
func (c *MsgClient) Destroy(ks ...string) error {
if len(c.callbacks) > 0 && len(ks) > 0 {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.consumer) > 0 && len(ks) > 0 {
for _, k := range ks {
delete(c.callbacks, k)
delete(c.consumer, k)
}
if len(c.callbacks) > 0 {
if len(c.consumer) > 0 {
return nil
}
}
// 关闭消息推送
c.callbacks = nil
c.consumer = nil
return c.close()
}

// 创建消息接收器
// param cb MsgCallback 消息回调函数
// param cb MsgConsumer 消息回调函数
// return string 接收器唯一标识
func (c *MsgClient) Register(cb MsgCallback) (string, error) {
func (c *MsgClient) Register(cb MsgConsumer) (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
k := strutil.Rand(16)
if c.callbacks == nil {
if c.consumer == nil {
if err := c.init(0); err != nil {
logman.Error("msg receiver", "error", err)
logman.Error("msg consumer", "error", err)
return "", err
}
c.callbacks = map[string]MsgCallback{
k: cb,
}
go func() {
defer c.Destroy()
defer recovery.Handler()
for len(c.callbacks) > 0 {
if resp, err := c.recv(); err == nil {
msg := resp.GetWxmsg()
for _, f := range c.callbacks {
f(msg) // 推送消息
}
} else {
logman.Error("msg receiver", "error", err)
}
}
logman.Warn("msg receiver stopped")
}()
c.consumer = map[string]MsgConsumer{k: cb}
go c.runner()
} else {
c.callbacks[k] = cb
c.consumer[k] = cb
}
return k, nil
}

// 消息推送执行者
func (c *MsgClient) runner() {
defer recovery.Handler()
defer c.Destroy()
// 接收消息
for len(c.consumer) > 0 {
if resp, err := c.recv(); err == nil {
msg := resp.GetWxmsg()
for _, f := range c.consumer {
f(msg) // 推送消息
}
} else {
logman.Error("msg consumer", "error", err)
}
}
// 连接断开
logman.Warn("msg consumer stopped")
}

0 comments on commit ddce2fb

Please sign in to comment.