From ddce2fbddb36b3f61974c66c9a61a5c1876b71f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8B=A5=E6=B5=B7?= Date: Sat, 9 Mar 2024 22:07:11 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=B6=88=E6=81=AF=E6=8E=A5=E6=94=B6?= =?UTF-8?q?=E5=99=A8=E5=B9=B6=E5=8F=91=E5=AF=BC=E8=87=B4=E7=9A=84=E5=BC=82?= =?UTF-8?q?=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- httpd/wcfrest/receiver_url.go | 2 +- httpd/wcfrest/receiver_ws.go | 2 +- wcferry/client.go | 8 ++--- wcferry/msg_client.go | 68 ++++++++++++++++++++--------------- 4 files changed, 45 insertions(+), 35 deletions(-) diff --git a/httpd/wcfrest/receiver_url.go b/httpd/wcfrest/receiver_url.go index 3e424570..13d52ab6 100644 --- a/httpd/wcfrest/receiver_url.go +++ b/httpd/wcfrest/receiver_url.go @@ -12,7 +12,7 @@ import ( var urlReceiverList = map[string]string{} -func urlReciever(url string) wcferry.MsgCallback { +func urlReciever(url string) wcferry.MsgConsumer { return func(msg *wcferry.WxMsg) { ret := wcferry.ParseWxMsg(msg) diff --git a/httpd/wcfrest/receiver_ws.go b/httpd/wcfrest/receiver_ws.go index c71127aa..00e1b67e 100644 --- a/httpd/wcfrest/receiver_ws.go +++ b/httpd/wcfrest/receiver_ws.go @@ -10,7 +10,7 @@ import ( "github.com/opentdp/wechat-rest/wcferry" ) -func socketReceiver(ws *websocket.Conn) wcferry.MsgCallback { +func socketReceiver(ws *websocket.Conn) wcferry.MsgConsumer { mu := sync.Mutex{} diff --git a/wcferry/client.go b/wcferry/client.go index 4d80c706..c07abdb0 100644 --- a/wcferry/client.go +++ b/wcferry/client.go @@ -51,10 +51,10 @@ func (c *Client) Connect() error { // 启动消息接收器 // param pyq bool 是否接收朋友圈消息 -// param cb MsgCallback 消息回调函数,可选参数 +// param cb MsgConsumer 消息回调函数,可选参数 // return string 接收器唯一标识 -func (c *Client) EnrollReceiver(pyq bool, cb MsgCallback) (string, error) { - if c.MsgClient.callbacks == nil { +func (c *Client) EnrollReceiver(pyq bool, cb MsgConsumer) (string, error) { + if c.MsgClient.consumer == nil { if c.CmdClient.EnableMsgReciver(true) != 0 { return "", errors.New("failed to enable msg server") } @@ -68,7 +68,7 @@ func (c *Client) EnrollReceiver(pyq bool, cb MsgCallback) (string, error) { // return error 错误信息 func (c *Client) DisableReceiver(ks ...string) error { err := c.MsgClient.Destroy(ks...) - if c.MsgClient.callbacks == nil { + if c.MsgClient.consumer == nil { if c.CmdClient.DisableMsgReciver() != 0 { return errors.New("failed to disable msg server") } diff --git a/wcferry/msg_client.go b/wcferry/msg_client.go index eebd83f8..332c0a47 100644 --- a/wcferry/msg_client.go +++ b/wcferry/msg_client.go @@ -1,6 +1,8 @@ package wcferry import ( + "sync" + "github.com/opentdp/go-helper/logman" "github.com/opentdp/go-helper/recovery" "github.com/opentdp/go-helper/strutil" @@ -8,59 +10,67 @@ import ( type MsgClient struct { *pbSocket // RPC 客户端 - callbacks map[string]MsgCallback // 推送函数列表 + mu sync.Mutex // 互斥锁 + consumer map[string]MsgConsumer // 消费者 } // 消息回调函数 -type MsgCallback func(msg *WxMsg) +type MsgConsumer func(msg *WxMsg) // 关闭 RPC 连接 // param ks 消息接收器标识,空则关闭所有 // return error 错误信息 func (c *MsgClient) Destroy(ks ...string) error { - if len(c.callbacks) > 0 && len(ks) > 0 { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.consumer) > 0 && len(ks) > 0 { for _, k := range ks { - delete(c.callbacks, k) + delete(c.consumer, k) } - if len(c.callbacks) > 0 { + if len(c.consumer) > 0 { return nil } } // 关闭消息推送 - c.callbacks = nil + c.consumer = nil return c.close() } // 创建消息接收器 -// param cb MsgCallback 消息回调函数 +// param cb MsgConsumer 消息回调函数 // return string 接收器唯一标识 -func (c *MsgClient) Register(cb MsgCallback) (string, error) { +func (c *MsgClient) Register(cb MsgConsumer) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() k := strutil.Rand(16) - if c.callbacks == nil { + if c.consumer == nil { if err := c.init(0); err != nil { - logman.Error("msg receiver", "error", err) + logman.Error("msg consumer", "error", err) return "", err } - c.callbacks = map[string]MsgCallback{ - k: cb, - } - go func() { - defer c.Destroy() - defer recovery.Handler() - for len(c.callbacks) > 0 { - if resp, err := c.recv(); err == nil { - msg := resp.GetWxmsg() - for _, f := range c.callbacks { - f(msg) // 推送消息 - } - } else { - logman.Error("msg receiver", "error", err) - } - } - logman.Warn("msg receiver stopped") - }() + c.consumer = map[string]MsgConsumer{k: cb} + go c.runner() } else { - c.callbacks[k] = cb + c.consumer[k] = cb } return k, nil } + +// 消息推送执行者 +func (c *MsgClient) runner() { + defer recovery.Handler() + defer c.Destroy() + // 接收消息 + for len(c.consumer) > 0 { + if resp, err := c.recv(); err == nil { + msg := resp.GetWxmsg() + for _, f := range c.consumer { + f(msg) // 推送消息 + } + } else { + logman.Error("msg consumer", "error", err) + } + } + // 连接断开 + logman.Warn("msg consumer stopped") +}