diff --git a/.gitignore b/.gitignore index c0ecedd..873555e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ .vscode/ .idea/ .ruff_cache/ +.DS_Store # config.ini 配置文件 config.ini diff --git a/main.py b/main.py index f2acd08..f736a30 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import uvicorn from loguru import logger +import wechatter.database as db import wechatter.utils.file_manager as fm from wechatter.init_logger import init_logger @@ -17,7 +18,6 @@ def main(): # 为了让此文件的 config 模块是首次导入,下面这些模块需要放到 config 导入之后 from wechatter.app.app import app from wechatter.bot.bot_info import BotInfo - from wechatter.sqlite.sqlite_manager import SqliteManager # isort: on BotInfo.update_name(config.bot_name) @@ -25,11 +25,8 @@ def main(): fm.check_and_create_folder("data/qrcodes") fm.check_and_create_folder("data/todos") fm.check_and_create_folder("data/text_image") - # 创建文件 - fm.check_and_create_file("data/wechatter.sqlite") - # 创建数据库表 - sqlite_manager = SqliteManager("data/wechatter.sqlite") - sqlite_manager.excute_folder("wechatter/sqlite/sqls") + + db.create_tables() logger.info("WeChatter 启动成功!") # 启动uvicorn diff --git a/pyproject.toml b/pyproject.toml index 7d60e80..22d2b53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,3 +12,4 @@ docstring-code-format = true line_length = 88 multi_line_output = 3 include_trailing_comma = true +combine_as_imports = true diff --git a/requirements.txt b/requirements.txt index 84470b1..dd7ccf3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,5 @@ apscheduler==3.10.4 pre-commit==3.5.0 loguru==0.7.2 tenacity==8.2.3 +pydantic==2.5.3 +sqlalchemy==2.0.25 diff --git a/tests/commands/test_bili_hot/test_bili_hot.py b/tests/commands/test_bili_hot/test_bili_hot.py index 0090e2c..2dd135a 100644 --- a/tests/commands/test_bili_hot/test_bili_hot.py +++ b/tests/commands/test_bili_hot/test_bili_hot.py @@ -1,13 +1,14 @@ -import unittest import json +import unittest + from wechatter.commands._commands import bili_hot class TestBiliHotCommand(unittest.TestCase): def setUp(self): - with open('tests/commands/test_bili_hot/bili_hot_response.json') as f: + with open("tests/commands/test_bili_hot/bili_hot_response.json") as f: self.r_json = json.load(f) - self.bili_hot_list = self.r_json['data']['list'] + self.bili_hot_list = self.r_json["data"]["list"] def test_extract_bili_hot_data_success(self): result = bili_hot._extract_bili_hot_data(self.r_json) @@ -18,12 +19,10 @@ def test_extract_bili_hot_data_failure(self): bili_hot._extract_bili_hot_data({}) def test_generate_bili_hot_message_success(self): - result = bili_hot._generate_bili_hot_message(self.r_json['data']['list']) - true_result = ( - "1. 复旦教师杀害学院书记被判死缓\n2. 免费领取心魔\n3. 扎克伯格道歉\n4. 易烊千玺当选中国影协理事\n5. kei和marin分手\n6. 上海禁止网约车在浦东机场运营\n7. 中国豆浆机在韩国热销\n8. 何同学工作室首次公开\n9. 官方称确定国花时机未成熟\n10. Aimer参演原神新春会\n11. 总台龙年春晚动画宣传片\n12. 只解冻刘德华太保守了\n13. 烟花引发住宅起火致老人去世\n14. 750万人在等尊严死\n15. 坠亡姐弟生母被送医院\n16. 不绵之夜\n17. 清华宣布脑机接口重大突破\n18. Uzi 使用替身攻击\n19. 捡雪吃患病连烧八天\n20. 幻塔EVA联动明日香登场" - ) + result = bili_hot._generate_bili_hot_message(self.r_json["data"]["list"]) + true_result = "1. 复旦教师杀害学院书记被判死缓\n2. 免费领取心魔\n3. 扎克伯格道歉\n4. 易烊千玺当选中国影协理事\n5. kei和marin分手\n6. 上海禁止网约车在浦东机场运营\n7. 中国豆浆机在韩国热销\n8. 何同学工作室首次公开\n9. 官方称确定国花时机未成熟\n10. Aimer参演原神新春会\n11. 总台龙年春晚动画宣传片\n12. 只解冻刘德华太保守了\n13. 烟花引发住宅起火致老人去世\n14. 750万人在等尊严死\n15. 坠亡姐弟生母被送医院\n16. 不绵之夜\n17. 清华宣布脑机接口重大突破\n18. Uzi 使用替身攻击\n19. 捡雪吃患病连烧八天\n20. 幻塔EVA联动明日香登场" self.assertIn(true_result, result) def test_generate_bili_hot_message_empty_list(self): result = bili_hot._generate_bili_hot_message([]) - self.assertEqual(result, '暂无Bilibili热搜') + self.assertEqual(result, "暂无Bilibili热搜") diff --git a/tests/commands/test_douyin_hot/test_douyin_hot.py b/tests/commands/test_douyin_hot/test_douyin_hot.py index f3e1097..b44ffcb 100644 --- a/tests/commands/test_douyin_hot/test_douyin_hot.py +++ b/tests/commands/test_douyin_hot/test_douyin_hot.py @@ -1,13 +1,14 @@ -import unittest import json +import unittest + from wechatter.commands._commands import douyin_hot class TestDouyinHotCommand(unittest.TestCase): def setUp(self): - with open('tests/commands/test_douyin_hot/douyin_hot_response.json') as f: + with open("tests/commands/test_douyin_hot/douyin_hot_response.json") as f: self.r_json = json.load(f) - self.douyin_hot_list = self.r_json['word_list'] + self.douyin_hot_list = self.r_json["word_list"] def test_extract_douyin_hot_data_success(self): result = douyin_hot._extract_douyin_hot_data(self.r_json) @@ -18,10 +19,10 @@ def test_extract_douyin_hot_data_failure(self): douyin_hot._extract_douyin_hot_data({}) def test_generate_douyin_hot_message_success(self): - result = douyin_hot._generate_douyin_hot_message(self.r_json['word_list']) + result = douyin_hot._generate_douyin_hot_message(self.r_json["word_list"]) true_result = "1. 董宇辉对刘德华说想演兵马俑\n2. 今年春联是自己写的\n3. 中国同23国全面互免签证\n4. 暗夜变装挑战\n5. 晒出你的新春红\n6. 相亲相爱接力挑战\n7. 诀别书的正确打开方式\n8. 2月第一天\n9. 北京产权交易所澄清声明\n10. 贵州有多钟爱办酒席\n11. 各地人这么做一定有原因\n12. 韩国瑜当选台立法机构负责人\n13. 一起跳甜妹手势舞\n14. 上海楼房凌晨坍塌 多方回应\n15. 刘德华宁浩红毯先生今晚直播\n16. 没有同桌我可怎么办啊\n17. 河北一保安阻止奔驰加塞被顶撞\n18. 网友过年爱上组养生局了\n19. 召集全抖音晚8找乐子\n20. 春运期间天气预测" self.assertIn(true_result, result) def test_generate_douyin_hot_message_empty_list(self): result = douyin_hot._generate_douyin_hot_message([]) - self.assertEqual(result, '暂无抖音热搜') + self.assertEqual(result, "暂无抖音热搜") diff --git a/tests/commands/test_people_daily/test_people_daily.py b/tests/commands/test_people_daily/test_people_daily.py index f6db6e7..360da4e 100644 --- a/tests/commands/test_people_daily/test_people_daily.py +++ b/tests/commands/test_people_daily/test_people_daily.py @@ -1,11 +1,15 @@ import unittest + from wechatter.commands._commands import people_daily class TestPeopleDailyCommand(unittest.TestCase): def test_get_people_daily_url_success(self): result = people_daily.get_people_daily_url("2024010901") - self.assertEqual(result, "http://paper.people.com.cn/rmrb/images/2024-01/09/01/rmrb2024010901.pdf") + self.assertEqual( + result, + "http://paper.people.com.cn/rmrb/images/2024-01/09/01/rmrb2024010901.pdf", + ) def test_get_people_daily_value_error(self): with self.assertRaises(ValueError): diff --git a/tests/commands/test_qrcode/test_qrcode.py b/tests/commands/test_qrcode/test_qrcode.py index da743ff..2d2898e 100644 --- a/tests/commands/test_qrcode/test_qrcode.py +++ b/tests/commands/test_qrcode/test_qrcode.py @@ -1,7 +1,9 @@ import os import unittest from unittest.mock import patch + import qrcode as qrc + from wechatter.commands._commands import qrcode @@ -15,7 +17,7 @@ def tearDown(self): os.remove(self.path) def test_generate_qrcode_success(self): - result = qrcode.generate_qrcode('https://www.baidu.com') + result = qrcode.generate_qrcode("https://www.baidu.com") self.assertIsInstance(result, qrc.image.pil.PilImage) def test_save_qrcode_attribute_error(self): @@ -26,7 +28,7 @@ def test_save_qrcode_successfully(self): qrcode.save_qrcode(self.img, self.path) self.assertTrue(os.path.exists(self.path)) - @patch('os.access', return_value=False) + @patch("os.access", return_value=False) def test_save_qrcode_permission_error(self, mock_access): with self.assertRaises(PermissionError): - qrcode.save_qrcode(self.img, self.path) \ No newline at end of file + qrcode.save_qrcode(self.img, self.path) diff --git a/tests/commands/test_today_in_history/test_today_in_history.py b/tests/commands/test_today_in_history/test_today_in_history.py index ee59c39..2e1efde 100644 --- a/tests/commands/test_today_in_history/test_today_in_history.py +++ b/tests/commands/test_today_in_history/test_today_in_history.py @@ -1,14 +1,16 @@ import json import unittest + from wechatter.commands._commands import today_in_history class TestTodayInHistoryCommand(unittest.TestCase): - def setUp(self): - with open('tests/commands/test_today_in_history/today_in_history_response.json') as f: + with open( + "tests/commands/test_today_in_history/today_in_history_response.json" + ) as f: self.tih_response = json.load(f) - self.tih_list = self.tih_response['data'] + self.tih_list = self.tih_response["data"] def test_extract_today_in_history_data_success(self): result = today_in_history._extract_today_in_history_data(self.tih_response) diff --git a/tests/commands/test_trivia/test_trivia.py b/tests/commands/test_trivia/test_trivia.py index c6cb192..3806d23 100644 --- a/tests/commands/test_trivia/test_trivia.py +++ b/tests/commands/test_trivia/test_trivia.py @@ -1,17 +1,18 @@ -import unittest import json +import unittest + from requests import Response + from wechatter.commands._commands import trivia from wechatter.exceptions import Bs4ParsingError class TestTriviaCommand(unittest.TestCase): - def setUp(self): with open("tests/commands/test_trivia/trivia_response.html.test") as f: r_html = f.read() self.response = Response() - self.response._content = r_html.encode('utf-8') + self.response._content = r_html.encode("utf-8") with open("tests/commands/test_trivia/trivia_data.json") as f: self.trivia_list = json.load(f) diff --git a/tests/commands/test_weibo_hot/test_weibo_hot.py b/tests/commands/test_weibo_hot/test_weibo_hot.py index 0610ebc..0063cb7 100644 --- a/tests/commands/test_weibo_hot/test_weibo_hot.py +++ b/tests/commands/test_weibo_hot/test_weibo_hot.py @@ -1,12 +1,12 @@ -import unittest import json +import unittest + from wechatter.commands._commands import weibo_hot class TestWeiboHotCommand(unittest.TestCase): - def setUp(self): - with open('tests/commands/test_weibo_hot/weibo_hot_response.json') as f: + with open("tests/commands/test_weibo_hot/weibo_hot_response.json") as f: self.r_json = json.load(f) self.weibo_hot_list = self.r_json["data"]["cards"][0]["card_group"][:20] @@ -25,4 +25,4 @@ def test_generate_weibo_hot_message_success(self): def test_generate_weibo_hot_message_empty_list(self): result = weibo_hot._generate_weibo_hot_message([]) - self.assertEqual(result, '微博热搜列表为空') + self.assertEqual(result, "微博热搜列表为空") diff --git a/tests/commands/test_zhihu_hot/test_zhihu_hot.py b/tests/commands/test_zhihu_hot/test_zhihu_hot.py index a27e203..7116130 100644 --- a/tests/commands/test_zhihu_hot/test_zhihu_hot.py +++ b/tests/commands/test_zhihu_hot/test_zhihu_hot.py @@ -1,13 +1,14 @@ -import unittest import json +import unittest + from wechatter.commands._commands import zhihu_hot class TestZhihuHotCommand(unittest.TestCase): def setUp(self): - with open('tests/commands/test_zhihu_hot/zhihu_hot_response.json') as f: + with open("tests/commands/test_zhihu_hot/zhihu_hot_response.json") as f: self.r_json = json.load(f) - self.zhihu_hot_list = self.r_json['data'] + self.zhihu_hot_list = self.r_json["data"] def test_extract_zhihu_hot_data_success(self): result = zhihu_hot._extract_zhihu_hot_data(self.r_json) @@ -19,11 +20,9 @@ def test_extract_zhihu_hot_data_failure(self): def test_generate_zhihu_hot_message_success(self): result = zhihu_hot._generate_zhihu_hot_message(self.zhihu_hot_list) - true_result = ( - "1. 得州州长警告拜登称,若「联邦化」得州国民警卫队将是政治错误,如何解读?该事件将会如何演变?\n2. 贵州一村民办酒席被挂工作证人员往食物撒盐,当地通报「已致歉」,如何评价此事?\n3. 美国联合包裹运送服务公司宣布将裁员 1.2 万人,将节省约 10 亿美元成本,哪些信息值得关注?\n4. 汞比金多一个质子,理论上汞比金更难形成,为什么汞比金便宜?\n5. 沙特伊朗等五国已正式成为金砖成员国,34 个国家提出书面申请,对金砖扩容做何展望?哪些信息值得关注?\n6. 23-24 赛季 NBA快船 125:109 奇才,如何评价这场比赛?\n7. 戴尔服务中国公司拟注销,哪些信息值得关注?\n8. 如何看待 2 月 1 日 A 股市场行情?\n9. 如何评价“祸不及家人的前提是惠不及家人”这个观点?\n10. 明明知道熬夜对身体不好,可就是戒不掉,我该怎么办?\n11. 「复旦研究员刺杀书记案」一审宣判被告死缓,犯故意杀人罪,哪些信息值得关注?如何看待此案?\n12. 白所成等10名缅北重大犯罪嫌疑人被押解回国,缅北四大家族尽数覆灭,哪些信息值得关注?\n13. 韩国瑜当选台新任立法机构负责人,释放了哪些信息?这对台湾当局执政有何影响?\n14. 财政部表示基本养老保险基金已累计结余近 6 万亿元,养老金按时足额发放有保证,哪些信息值得关注?\n15. 河南暴雪来临,洛阳等地已开始下雪,目前情况如何?会对春运出行带来哪些影响?\n16. 同事听到我用“免贵姓…”回答“您贵姓?”这个问题的时候全都笑了,是哪里出了问题?\n17. 如何评价《崩坏:星穹铁道》千星纪游PV:「旧梦重温」?\n18. 23-24 赛季英超利物浦 4:1 切尔西,如何评价这场比赛?\n19. 孩子应该做家务、做饭吗?\n20. 你在跑步路上遇到的最大的阻碍是什么?" - ) + true_result = "1. 得州州长警告拜登称,若「联邦化」得州国民警卫队将是政治错误,如何解读?该事件将会如何演变?\n2. 贵州一村民办酒席被挂工作证人员往食物撒盐,当地通报「已致歉」,如何评价此事?\n3. 美国联合包裹运送服务公司宣布将裁员 1.2 万人,将节省约 10 亿美元成本,哪些信息值得关注?\n4. 汞比金多一个质子,理论上汞比金更难形成,为什么汞比金便宜?\n5. 沙特伊朗等五国已正式成为金砖成员国,34 个国家提出书面申请,对金砖扩容做何展望?哪些信息值得关注?\n6. 23-24 赛季 NBA快船 125:109 奇才,如何评价这场比赛?\n7. 戴尔服务中国公司拟注销,哪些信息值得关注?\n8. 如何看待 2 月 1 日 A 股市场行情?\n9. 如何评价“祸不及家人的前提是惠不及家人”这个观点?\n10. 明明知道熬夜对身体不好,可就是戒不掉,我该怎么办?\n11. 「复旦研究员刺杀书记案」一审宣判被告死缓,犯故意杀人罪,哪些信息值得关注?如何看待此案?\n12. 白所成等10名缅北重大犯罪嫌疑人被押解回国,缅北四大家族尽数覆灭,哪些信息值得关注?\n13. 韩国瑜当选台新任立法机构负责人,释放了哪些信息?这对台湾当局执政有何影响?\n14. 财政部表示基本养老保险基金已累计结余近 6 万亿元,养老金按时足额发放有保证,哪些信息值得关注?\n15. 河南暴雪来临,洛阳等地已开始下雪,目前情况如何?会对春运出行带来哪些影响?\n16. 同事听到我用“免贵姓…”回答“您贵姓?”这个问题的时候全都笑了,是哪里出了问题?\n17. 如何评价《崩坏:星穹铁道》千星纪游PV:「旧梦重温」?\n18. 23-24 赛季英超利物浦 4:1 切尔西,如何评价这场比赛?\n19. 孩子应该做家务、做饭吗?\n20. 你在跑步路上遇到的最大的阻碍是什么?" self.assertIn(true_result, result) def test_generate_zhihu_hot_message_empty_list(self): result = zhihu_hot._generate_zhihu_hot_message([]) - self.assertEqual(result, '暂无知乎热搜') + self.assertEqual(result, "暂无知乎热搜") diff --git a/wechatter/app/routers/wechat.py b/wechatter/app/routers/wechat.py index 788f007..d526eee 100644 --- a/wechatter/app/routers/wechat.py +++ b/wechatter/app/routers/wechat.py @@ -7,11 +7,18 @@ import wechatter.config as config from wechatter.bot.bot_info import BotInfo from wechatter.commands import commands +from wechatter.database import ( + Group as DbGroup, + Message as DbMessage, + Person as DbPerson, + make_db_session, +) from wechatter.message import MessageHandler from wechatter.message_forwarder import MessageForwarder -from wechatter.models.message import Message +from wechatter.models.wechat import Message +from wechatter.models.wechat.group import Group +from wechatter.models.wechat.person import Person from wechatter.sender import notifier -from wechatter.sqlite.sqlite_manager import SqliteManager router = APIRouter() @@ -21,16 +28,18 @@ async def recv_wechat_msg( type: str = Form(), content: Union[UploadFile, str] = Form(), source: str = Form(), - isMentioned: str = Form(), - isSystemEvent: str = Form(), + is_mentioned: str = Form(alias="isMentioned"), + is_system_event: str = Form(alias="isSystemEvent"), ): - """接收Docker转发过来的消息的接口""" + """ + 接收Docker转发过来的消息的接口 + """ # 更新机器人信息(id和name) BotInfo.update_from_source(source) # 判断是否是系统事件 - if isSystemEvent == "1": + if is_system_event == "1": logger.info(f"收到系统事件:{content}") handle_system_event(content) return @@ -42,24 +51,22 @@ async def recv_wechat_msg( # 解析命令 # 构造消息对象 - message = Message( + message = Message.from_api_msg( type=type, content=content, source=source, - is_mentioned=isMentioned, + is_mentioned=is_mentioned, ) + # 向群组表中添加该群组 + add_group(message.group) # 向用户表中添加该用户 - check_and_add_user( - user_id=message.source.p_info.id, - user_name=message.source.p_info.name, - user_alias=message.source.p_info.alias, - user_gender=message.source.p_info.gender, - ) + add_person(message.person) + # 向消息表中添加该消息 + message.id = add_message(message) + # TODO: 添加自己发送的消息,等待 wechatbot-webhook 支持 # DEBUG - print("==" * 20) print(str(message)) - print("==" * 20) if config.message_forwarding_enabled: MessageForwarder(config.message_forwarding_rule_list).forward_message(message) @@ -74,14 +81,16 @@ async def recv_wechat_msg( def handle_system_event(content: str) -> None: - """判断系统事件类型,并调用相应的函数""" + """ + 判断系统事件类型,并调用相应的函数 + """ content_dict: dict = json.loads(content) # 判断是否为机器人登录消息 if content_dict["event"] == "login": - print("机器人登录成功") + logger.info("机器人登录成功") notifier.notify_logged_in() elif content_dict["event"] == "logout": - print("机器人已退出登录") + logger.info("机器人已退出登录") notifier.notify_logged_out() elif content_dict["event"] == "error": pass @@ -89,29 +98,65 @@ def handle_system_event(content: str) -> None: pass -# TODO: 判断传入的参数和数据库中的数据是否一致,若不一致,则更新数据库中的数据 -def check_and_add_user( - user_id: str, user_name: str = "", user_alias: str = "", user_gender: int = -1 -) -> None: - """判断用户表中是否有该用户,若没有,则添加该用户""" - sqlite_manager = SqliteManager() - sql = "SELECT * FROM wx_users WHERE wx_id = ?" - result = sqlite_manager.fetch_one(sql, (user_id,)) - if result is not None: +def add_group(group: Group) -> None: + """ + 判断群组表中是否有该群组,若没有,则添加该群组 + """ + if group is None: return - # 该用户不存在,添加该用户 - gender = "unknown" - if user_gender == 1: - gender = "male" - elif user_gender == 0: - gender = "female" - sql = "INSERT INTO wx_users(wx_id, wx_name, wx_alias, wx_gender) VALUES(?, ?, ?, ?)" - sqlite_manager.insert( - "wx_users", - { - "wx_id": user_id, - "wx_name": user_name, - "wx_alias": user_alias, - "wx_gender": gender, - }, - ) + with make_db_session() as session: + _group = session.query(DbGroup).filter(DbGroup.id == group.id).first() + if _group is None: + _group = DbGroup.from_model(group) + session.add(_group) + # 逐个添加群组成员,若存在则更新 + for member in group.member_list: + _person = ( + session.query(DbPerson).filter(DbPerson.id == member.id).first() + ) + if _person is None: + _person = DbPerson.from_member_model(member) + session.add(_person) + session.commit() + logger.info(f"用户 {member.name} 已添加到数据库") + else: + # 更新用户信息 + _person.name = member.name + _person.alias = member.alias + session.commit() + + session.commit() + logger.info(f"群组 {group.name} 已添加到数据库") + else: + # 更新群组信息 + _group.update(group) + session.commit() + + +def add_person(person: Person) -> None: + """ + 判断用户表中是否有该用户,若没有,则添加该用户 + """ + with make_db_session() as session: + _person = session.query(DbPerson).filter(DbPerson.id == person.id).first() + if _person is None: + _person = DbPerson.from_model(person) + session.add(_person) + session.commit() + logger.info(f"用户 {person.name} 已添加到数据库") + else: + # 更新用户信息 + _person.update(person) + session.commit() + + +def add_message(message: Message) -> int: + """ + 添加消息到消息表 + """ + with make_db_session() as session: + _message = DbMessage.from_model(message) + session.add(_message) + session.commit() + logger.info(f"消息 {_message.id} 已添加到数据库") + return _message.id diff --git a/wechatter/commands/_commands/bili_hot.py b/wechatter/commands/_commands/bili_hot.py index e7cebc9..d1f96ea 100644 --- a/wechatter/commands/_commands/bili_hot.py +++ b/wechatter/commands/_commands/bili_hot.py @@ -3,7 +3,7 @@ from loguru import logger from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request_json diff --git a/wechatter/commands/_commands/copilot_gpt4.py b/wechatter/commands/_commands/copilot_gpt4.py index cc98ffa..2f722e4 100644 --- a/wechatter/commands/_commands/copilot_gpt4.py +++ b/wechatter/commands/_commands/copilot_gpt4.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import List, Union from loguru import logger @@ -5,15 +6,20 @@ import wechatter.config as config import wechatter.utils.path_manager as pm from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.database import ( + GptChatInfo as DbGptChatInfo, + GptChatMessage as DbGptChatMessage, + make_db_session, +) +from wechatter.models.gpt import GptChatInfo +from wechatter.models.wechat import Person, SendTo from wechatter.sender import sender -from wechatter.sqlite.sqlite_manager import SqliteManager from wechatter.utils import post_request_json -from wechatter.utils.time import get_current_timestamp DEFAULT_TOPIC = "(对话进行中*)" -DEFAULT_MODEL = "gpt-4" -DEFAULT_CONVERSATIONS = [{"role": "system", "content": "你是一位乐于助人的助手"}] +# DEFAULT_MODEL = "gpt-4" +# TODO: 初始化对话,Prompt选择 +DEFAULT_CONVERSATION = [{"role": "system", "content": "你是一位乐于助人的助手"}] @command( @@ -21,8 +27,8 @@ keys=["gpt"], desc="使用GPT3.5进行对话。", ) -def gpt35_command_handler(to: SendTo, message: str = "") -> None: - _gptx("gpt-3.5-turbo", to, message) +def gpt35_command_handler(to: SendTo, message: str = "", message_obj=None) -> None: + _gptx("gpt-3.5-turbo", to, message, message_obj) @command( @@ -30,8 +36,10 @@ def gpt35_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt-chats", "gpt对话记录"], desc="列出GPT3.5对话记录。", ) -def gpt35_chats_command_handler(to: SendTo, message: str = "") -> None: - _gptx_chats("gpt-3.5-turbo", to, message) +def gpt35_chats_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: + _gptx_chats("gpt-3.5-turbo", to, message, message_obj) @command( @@ -39,7 +47,9 @@ def gpt35_chats_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt-record", "gpt记录"], desc="获取GPT3.5对话记录。", ) -def gpt35_record_command_handler(to: SendTo, message: str = "") -> None: +def gpt35_record_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_record("gpt-3.5-turbo", to, message) @@ -48,7 +58,9 @@ def gpt35_record_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt-continue", "gpt继续"], desc="继续GPT3.5对话。", ) -def gpt35_continue_command_handler(to: SendTo, message: str = "") -> None: +def gpt35_continue_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_continue("gpt-3.5-turbo", to, message) @@ -57,8 +69,8 @@ def gpt35_continue_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4"], desc="使用GPT4进行对话。", ) -def gpt4_command_handler(to: SendTo, message: str = "") -> None: - _gptx("gpt-4", to, message) +def gpt4_command_handler(to: SendTo, message: str = "", message_obj=None) -> None: + _gptx("gpt-4", to, message, message_obj) @command( @@ -66,8 +78,8 @@ def gpt4_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4-chats", "gpt4对话记录"], desc="列出GPT4对话记录。", ) -def gpt4_chats_command_handler(to: SendTo, message: str = "") -> None: - _gptx_chats("gpt-4", to, message) +def gpt4_chats_command_handler(to: SendTo, message: str = "", message_obj=None) -> None: + _gptx_chats("gpt-4", to, message, message_obj) @command( @@ -75,7 +87,9 @@ def gpt4_chats_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4-record", "gpt4记录"], desc="获取GPT4对话记录。", ) -def gpt4_record_command_handler(to: SendTo, message: str = "") -> None: +def gpt4_record_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_record("gpt-4", to, message) @@ -84,25 +98,29 @@ def gpt4_record_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4-continue", "gpt4继续"], desc="继续GPT4对话。", ) -def gpt4_continue_command_handler(to: SendTo, message: str = "") -> None: +def gpt4_continue_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_continue("gpt-4", to, message) # TODO: # 命令:/gpt4-remove -def gpt4_remove_command_handler(to: SendTo, message: str = "") -> None: +def gpt4_remove_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: pass -def _gptx(model: str, to: SendTo, message: str = "") -> None: - wx_id = to.p_id +def _gptx(model: str, to: SendTo, message: str = "", message_obj=None) -> None: + person = to.person # 获取文件夹下最新的对话记录 - chat_info = CopilotGPT4.get_chating_chat_info(wx_id, model) + chat_info = CopilotGPT4.get_chatting_chat_info(person, model) if message == "": # /gpt4 # 判断对话是否有效 sender.send_msg(to, "正在创建新对话...") - if chat_info is None or CopilotGPT4.is_chat_valid(chat_info): - CopilotGPT4.create_chat(wx_id=wx_id, model=model) + if chat_info is None or CopilotGPT4._is_chat_valid(chat_info): + CopilotGPT4.create_chat(person, model) logger.info("创建新对话成功") sender.send_msg(to, "创建新对话成功") return @@ -112,11 +130,13 @@ def _gptx(model: str, to: SendTo, message: str = "") -> None: # 如果没有对话记录,则创建新对话 sender.send_msg(to, f"正在调用 {model} 进行对话...") if chat_info is None: - chat_info = CopilotGPT4.create_chat(wx_id=wx_id, model=model) + chat_info = CopilotGPT4.create_chat(person, model) logger.info("无历史对话记录,创建新对话成功") sender.send_msg(to, "无历史对话记录,创建新对话成功") try: - response = CopilotGPT4.chat(chat_info, message) + response = CopilotGPT4.chat( + chat_info, message=message, message_obj=message_obj + ) logger.info(response) sender.send_msg(to, response) except Exception as e: @@ -125,22 +145,21 @@ def _gptx(model: str, to: SendTo, message: str = "") -> None: sender.send_msg(to, error_message) -def _gptx_chats(model: str, to: SendTo, message: str = "") -> None: - response = CopilotGPT4.get_chat_list_str(to.p_id, model) +def _gptx_chats(model: str, to: SendTo, message: str = "", message_obj=None) -> None: + response = CopilotGPT4.get_chat_list_str(to.person, model) sender.send_msg(to, response) -def _gptx_record(model: str, to: SendTo, message: str = "") -> None: - wx_id = to.p_id - chat_info = None +def _gptx_record(model: str, to: SendTo, message: str = ""): + person = to.person if message == "": # 获取当前对话的对话记录 - chat_info = CopilotGPT4.get_chating_chat_info(wx_id, model) + chat_info = CopilotGPT4.get_chatting_chat_info(person, model) else: # 获取指定对话的对话记录 - chat_info = CopilotGPT4.get_chat_info(wx_id, model, int(message)) + chat_info = CopilotGPT4.get_chat_info(person, model, int(message)) if chat_info is None: - logger.waring("对话不存在") + logger.warning("对话不存在") sender.send_msg(to, "对话不存在") return response = CopilotGPT4.get_brief_conversation_str(chat_info) @@ -149,7 +168,7 @@ def _gptx_record(model: str, to: SendTo, message: str = "") -> None: def _gptx_continue(model: str, to: SendTo, message: str = "") -> None: - wx_id = to.p_id + person = to.person # 判断message是否为数字 if not message.isdigit(): logger.info("请输入对话记录编号") @@ -157,12 +176,12 @@ def _gptx_continue(model: str, to: SendTo, message: str = "") -> None: return sender.send_msg(to, f"正在切换到对话记录 {message}...") chat_info = CopilotGPT4.continue_chat( - wx_id=wx_id, model=model, chat_index=int(message) + person=person, model=model, chat_index=int(message) ) if chat_info is None: - waring_message = "选择历史对话失败,对话不存在" - logger.waring(waring_message) - sender.send_msg(to, waring_message) + warning_message = "选择历史对话失败,对话不存在" + logger.warning(warning_message) + sender.send_msg(to, warning_message) return response = CopilotGPT4.get_brief_conversation_str(chat_info) response += "====================\n" @@ -171,356 +190,245 @@ def _gptx_continue(model: str, to: SendTo, message: str = "") -> None: sender.send_msg(to, response) -class ChatInfo: - """对话信息(与 copilot_chats 表对应)""" - - def __init__( - self, - wx_id: str = "", - chat_created_time: int = -1, - chat_talk_time: int = -1, - chat_topic: str = DEFAULT_TOPIC, - chat_model: str = DEFAULT_MODEL, - conversations: List[dict] = DEFAULT_CONVERSATIONS, - is_chating: bool = False, - chat_id: int = -1, - ): - self.chat_id = chat_id - self.wx_id = wx_id - self.chat_created_time = chat_created_time - self.chat_talk_time = chat_talk_time - self.chat_topic = chat_topic - self.chat_model = chat_model - self.conversations = conversations - self.is_chating = is_chating - - @property - def has_topic(self) -> bool: - """是否有对话主题""" - if self.chat_topic == DEFAULT_TOPIC: - return False - return True - - @property - def dict(self) -> dict: - """将对象转为字典(删去 conversations 字段)""" - chat_info_dict = self.__dict__.copy() - chat_info_dict.pop("conversations") - return chat_info_dict - - class CopilotGPT4: - """Copilot-GPT4""" - api = f"{config.cp_gpt4_api_host}:{config.cp_gpt4_port}/v1/chat/completions" bearer_token = "Bearer " + config.cp_token save_path = pm.get_abs_path("data/copilot_gpt4/chats/") @staticmethod - def create_chat(wx_id: str, model: str) -> ChatInfo: - """创建一个新的对话""" + def create_chat(person: Person, model: str) -> GptChatInfo: + """ + 创建一个新的对话 + :param person: 用户 + :param model: 模型 + :return: 新的对话信息 + """ # 生成上一次对话的主题 - CopilotGPT4._save_chating_chat_topic(wx_id, model) - CopilotGPT4._set_all_chats_unchating(wx_id, model) - timestamp = get_current_timestamp() - chat_info = ChatInfo( - wx_id=wx_id, - chat_created_time=timestamp, - chat_talk_time=timestamp, - chat_model=model, - is_chating=True, - ) - # 保存对话记录 - chat_info_dict = chat_info.dict - # 删去 chat_id 字段,让数据库自动生成 - chat_info_dict.pop("chat_id") - sqlm = SqliteManager() - sqlm.insert("copilot_chats", chat_info_dict) - # 获取 SQLite 自动生成的 chat_id - sql = ( - "SELECT chat_id " - "FROM copilot_chats " - "WHERE wx_id = ? AND is_chating = TRUE AND chat_model = ? " + CopilotGPT4._save_chatting_chat_topic(person, model) + CopilotGPT4._set_all_chats_not_chatting(person, model) + gpt_chat_info = GptChatInfo( + person=person, + model=model, + topic=DEFAULT_TOPIC, + is_chatting=True, ) - result = sqlm.fetch_one(sql, (wx_id, model)) - chat_info.chat_id = result[0] - # 插入对话记录 - sqlm.insert( - "chat_conversations", - { - "chat_id": chat_info.chat_id, - "conversation_role": chat_info.conversations[0]["role"], - "conversation_content": chat_info.conversations[0]["content"], - "conversation_timestamp": get_current_timestamp(), - }, - ) - return chat_info + with make_db_session() as session: + _gpt_chat_info = DbGptChatInfo.from_model(gpt_chat_info) + session.add(_gpt_chat_info) + session.commit() + # 获取 SQLite 自动生成的 chat_id + session.refresh(_gpt_chat_info) + gpt_chat_info = _gpt_chat_info.to_model() + return gpt_chat_info @staticmethod - def continue_chat(wx_id: str, model: str, chat_index: int) -> Union[ChatInfo, None]: - """继续对话,从对话记录文件中读取对话记录 - :param wx_id: 微信用户ID + def continue_chat( + person: Person, model: str, chat_index: int + ) -> Union[GptChatInfo, None]: + """ + 继续对话,选择历史对话 + :param person: 用户 + :param model: 模型 :param chat_index: 对话记录索引(从1开始) - :return: 简略的对话记录 + :return: 对话信息 """ # 读取对话记录文件 - chat_info = CopilotGPT4.get_chat_info(wx_id, model, chat_index) + chat_info = CopilotGPT4.get_chat_info(person, model, chat_index) if chat_info is None: return None - chating_chat_info = CopilotGPT4.get_chating_chat_info(wx_id, model) - if not CopilotGPT4.is_chat_valid(chating_chat_info): - # 如果对话无效,则删除该对话记录后再继续对话 - CopilotGPT4._delete_chat(wx_id, chating_chat_info.chat_id) - else: - # 生成上一次对话的主题 - CopilotGPT4._save_chating_chat_topic(wx_id, model) - CopilotGPT4._set_chating_chat(wx_id, model, chat_info.chat_id) + chatting_chat_info = CopilotGPT4.get_chatting_chat_info(person, model) + if chatting_chat_info: + if not CopilotGPT4._is_chat_valid(chatting_chat_info): + # 如果对话无效,则删除该对话记录后再继续对话 + CopilotGPT4._delete_chat(chatting_chat_info) + else: + # 生成上一次对话的主题 + CopilotGPT4._save_chatting_chat_topic(person, model) + CopilotGPT4._set_chatting_chat(person, model, chat_info) return chat_info @staticmethod - def _set_chating_chat(wx_id: str, model: str, chat_id: int) -> None: - """设置正在进行中的对话记录""" - # 先将所有对话记录的 is_chating 字段设置为 False - CopilotGPT4._set_all_chats_unchating(wx_id, model) - sqlm = SqliteManager() - sqlm.update( - "copilot_chats", - {"is_chating": True}, - f"wx_id = '{wx_id}' AND chat_id = {chat_id}", - ) + def _set_chatting_chat(person: Person, model: str, chat_info: GptChatInfo) -> None: + """ + 设置正在进行中的对话记录 + """ + # 先将所有对话记录的 is_chatting 字段设置为 False + CopilotGPT4._set_all_chats_not_chatting(person, model) + with make_db_session() as session: + chat_info = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() + if chat_info is None: + logger.error("对话记录不存在") + raise ValueError("对话记录不存在") + chat_info.is_chatting = True + session.commit() @staticmethod - def _delete_chat(wx_id: str, chat_id: int) -> None: - """删除对话记录""" - sqlm = SqliteManager() - # 先删除对话记录 - sqlm.delete("chat_conversations", f"chat_id = {chat_id}") - # 再删除对话元数据 - sqlm.delete("copilot_chats", f"wx_id = '{wx_id}' AND chat_id = {chat_id}") + def _delete_chat(chat_info: GptChatInfo) -> None: + """ + 删除对话记录 + """ + with make_db_session() as session: + session.query(DbGptChatMessage).filter_by(gpt_chat_id=chat_info.id).delete() + session.query(DbGptChatInfo).filter_by(id=chat_info.id).delete() + session.commit() @staticmethod - def get_brief_conversation_str(chat_info: ChatInfo) -> str: - """获取对话记录的字符串""" - conversation_str = f"✨==={chat_info.chat_topic}===✨\n" - if chat_info == []: - conversation_str += "无对话记录" - return conversation_str - for conv in chat_info.conversations[-10:]: - content = conv["content"][:30] - if len(conv["content"]) > 30: - content += "..." - if conv["role"] == "system": - conversation_str += f"⭐️:{content}\n" - elif conv["role"] == "assistant": - conversation_str += f"🤖:{content}\n" - elif conv["role"] == "user": + def get_brief_conversation_str(chat_info: GptChatInfo) -> str: + """ + 获取对话记录的字符串 + :param chat_info: 对话记录 + :return: 对话记录字符串 + """ + with make_db_session() as session: + chat_info = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() + if chat_info is None: + logger.error("对话记录不存在") + raise ValueError("对话记录不存在") + conversation_str = f"✨==={chat_info.topic}===✨\n" + if not chat_info.gpt_chat_messages: + conversation_str += " 无对话记录" + return conversation_str + for msg in chat_info.gpt_chat_messages: + content: str = msg.message.content + # 合并成一行,提升观感 + content = content.replace("\n", "") + # 去掉命令前缀和命令关键词 + content = content[content.find(" ") + 1 :][:30] + response = msg.gpt_response[:30] + response = response.replace("\n", "") + if len(msg.message.content) > 30: + content += "..." + if len(msg.gpt_response) > 30: + response += "..." conversation_str += f"💬:{content}\n" - return conversation_str - - @staticmethod - def _get_brief_conversation_content(conversation: List) -> List: - """获取简略的对话记录的内容""" - content_list = [] - for conv in conversation[1:]: - if len(conv["content"]) > 20: - conv["content"] = conv["content"][:20] + "..." - content_list.append(conv["content"]) - return content_list - - @staticmethod - def _set_all_chats_unchating(wx_id: str, model: str) -> None: - """将所有对话记录的 is_chating 字段设置为 False""" - sqlm = SqliteManager() - sqlm.update( - "copilot_chats", - {"is_chating": False}, - f"wx_id = '{wx_id}' AND chat_model = '{model}'", - ) + conversation_str += f"🤖:{response}\n" + return conversation_str @staticmethod - def is_chat_valid(chat_info: ChatInfo) -> bool: - """判断对话是否有效""" - # 通过 conversation 长度判断对话是否有效 - if len(chat_info.conversations) <= 1: - return False - return True + def _set_all_chats_not_chatting(person: Person, model: str) -> None: + """ + 将所有对话记录的 is_chatting 字段设置为 False + """ + with make_db_session() as session: + session.query(DbGptChatInfo).filter_by( + person_id=person.id, model=model + ).update({"is_chatting": False}) + session.commit() @staticmethod - def _list_chat_info(wx_id: str, model: str) -> List: - """列出用户的所有对话记录""" - # 读取对话记录文件夹,按照 chat_talk_time 字段倒序排序,取前20个 - sqlm = SqliteManager() - sql = ( - "SELECT chat_id, wx_id, chat_created_time, chat_talk_time, chat_topic, chat_model, is_chating " - "FROM copilot_chats " - "WHERE wx_id = ? AND chat_model = ? " - "ORDER BY " - "CASE WHEN is_chating THEN 1 ELSE 0 END DESC, " - "chat_talk_time DESC LIMIT 20 " - ) - result = sqlm.fetch_all(sql, (wx_id, model)) - chat_info_list = [] - for chat in result: - chat_info_list.append( - ChatInfo( - chat_id=chat[0], - wx_id=chat[1], - chat_created_time=chat[2], - chat_talk_time=chat[3], - chat_topic=chat[4], - chat_model=chat[5], - is_chating=chat[6], + def _list_chat_info(person: Person, model: str) -> List: + """ + 列出用户的所有对话记录 + """ + # 按照 chat_talk_time 字段倒序排序,取前20个 + with make_db_session() as session: + chat_info_list = ( + session.query(DbGptChatInfo) + .filter_by(person_id=person.id, model=model) + .order_by( + DbGptChatInfo.is_chatting.desc(), + DbGptChatInfo.talk_time.desc(), ) + .limit(20) + .all() ) - return chat_info_list + _chat_info_list = [] + for chat_info in chat_info_list: + _chat_info_list.append(chat_info.to_model()) + return _chat_info_list @staticmethod - def get_chat_list_str(wx_id: str, model: str) -> str: - """获取用户的所有对话记录""" - chat_info_list = CopilotGPT4._list_chat_info(wx_id, model) - chat_info_list_str = "✨===GPT4对话记录===✨\n" - if chat_info_list == []: + def get_chat_list_str(person: Person, model: str) -> str: + """ + 获取用户的所有对话记录 + :param person: 用户 + :param model: 模型 + :return: 对话记录 + """ + chat_info_list = CopilotGPT4._list_chat_info(person, model) + chat_info_list_str = f"✨==={model}对话记录===✨\n" + if not chat_info_list: chat_info_list_str += " 📭 无对话记录" return chat_info_list_str - for i, chat in enumerate(chat_info_list): - if chat.is_chating: - chat_info_list_str += f"{i+1}. 💬{chat.chat_topic}\n" - else: - chat_info_list_str += f"{i+1}. {chat.chat_topic}\n" - return chat_info_list_str - - @staticmethod - def _update_chat(chat_info: ChatInfo, newconv: List = []) -> None: - """保存对话记录 - :param chat_info: 对话记录数据 - :param newconv: 新增对话记录 - """ - # 对话记录格式 - chat_info.chat_talk_time = get_current_timestamp() - sqlm = SqliteManager() - chat_info_dict = chat_info.dict - # 更新对话元数据 - sqlm.update( - "copilot_chats", - chat_info_dict, - f"chat_id = {chat_info.chat_id}", - ) - # 插入对话记录 - for conv in newconv: - sqlm.insert( - "chat_conversations", - { - "chat_id": chat_info.chat_id, - "conversation_role": conv["role"], - "conversation_content": conv["content"], - "conversation_timestamp": get_current_timestamp(), - }, - ) + with make_db_session() as session: + for i, chat_info in enumerate(chat_info_list): + chat = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() + if chat.is_chatting: + chat_info_list_str += f"{i + 1}. 💬{chat.topic}\n" + else: + chat_info_list_str += f"{i + 1}. {chat.topic}\n" + return chat_info_list_str @staticmethod - def get_chat_info(wx_id: str, model: str, chat_index: int) -> Union[ChatInfo, None]: - """获取用户的对话信息""" - chat_index = chat_index - 1 - sql = ( - "SELECT chat_id, wx_id, chat_created_time, chat_talk_time, chat_topic, chat_model, is_chating " - "FROM copilot_chats " - "WHERE wx_id = ? AND chat_model = ? " - "ORDER BY " - "CASE WHEN is_chating THEN 1 ELSE 0 END DESC, " - "chat_talk_time DESC LIMIT 20 " - ) - sqlm = SqliteManager() - result = sqlm.fetch_all(sql, (wx_id, model)) - if result == []: + def get_chat_info( + person: Person, model: str, chat_index: int + ) -> Union[GptChatInfo, None]: + """ + 获取用户的对话信息 + :param person: 用户 + :param model: 模型 + :param chat_index: 对话记录索引(从1开始) + :return: 对话信息 + """ + chat_info_id_list = CopilotGPT4._list_chat_info(person, model) + if not chat_info_id_list: return None - if len(result) <= chat_index: + if chat_index <= 0 or chat_index > len(chat_info_id_list): return None - chat = result[chat_index] - # 获取对话记录 - conv = CopilotGPT4._get_chat_conversations(chat[0]) - chat_info = ChatInfo( - chat_id=chat[0], - wx_id=chat[1], - chat_created_time=chat[2], - chat_talk_time=chat[3], - chat_topic=chat[4], - chat_model=chat[5], - conversations=conv, - is_chating=chat[6], - ) - return chat_info + return chat_info_id_list[chat_index - 1] @staticmethod - def _get_chat_conversations(chat_id: int) -> List: - sql = ( - "SELECT conversation_role, conversation_content, conversation_timestamp FROM chat_conversations " - "WHERE chat_id = ?" - ) - sqlm = SqliteManager() - result = sqlm.fetch_all(sql, (chat_id,)) - conversations = [] - for conv in result: - conversations.append( - { - "role": conv[0], - "content": conv[1], - "timestamp": conv[2], - }, + def get_chatting_chat_info(person: Person, model: str) -> Union[GptChatInfo, None]: + """ + 获取正在进行中的对话信息 + :param person: 用户 + :param model: 模型 + :return: 对话信息 + """ + with make_db_session() as session: + chat_info = ( + session.query(DbGptChatInfo) + .filter_by(person_id=person.id, model=model, is_chatting=True) + .first() ) - return conversations + if not chat_info: + return None + return chat_info.to_model() @staticmethod - def get_chating_chat_info(wx_id: str, model: str) -> Union[ChatInfo, None]: - """获取正在进行中的对话信息""" - # 获取对话元信息 - sql = ( - "SELECT chat_id, wx_id, chat_created_time, chat_talk_time, chat_topic, chat_model, is_chating " - "FROM copilot_chats " - "WHERE wx_id = ? AND is_chating = TRUE AND chat_model = ? " - ) - sqlm = SqliteManager() - meta_info = sqlm.fetch_one(sql, (wx_id, model)) - if meta_info is None: - return None - # 获取对话记录 - conv = CopilotGPT4._get_chat_conversations(meta_info[0]) - return ChatInfo( - chat_id=meta_info[0], - wx_id=meta_info[1], - chat_created_time=meta_info[2], - chat_talk_time=meta_info[3], - chat_topic=meta_info[4], - chat_model=meta_info[5], - conversations=conv, - is_chating=meta_info[6], - ) - - @staticmethod - def chat(chat_info: ChatInfo, message: str) -> str: - """使用 Copilot-GPT4-Server 持续对话""" + def chat(chat_info: GptChatInfo, message: str, message_obj) -> str: + """ + 持续对话 + :param chat_info: 对话信息 + :param message: 用户消息 + :param message_obj: 消息对象 + :return: GPT 回复 + """ # 对外暴露的对话方法,必须保存对话记录 - return CopilotGPT4._chat(chat_info=chat_info, message=message, is_save=True) + return CopilotGPT4._chat( + chat_info=chat_info, message=message, message_obj=message_obj, is_save=True + ) @staticmethod - def _chat(chat_info: ChatInfo, message: str, is_save: bool = True) -> str: - """使用 Copilot-GPT4-Server 持续对话 + def _chat(chat_info: GptChatInfo, message: str, message_obj, is_save: bool) -> str: + """ + 持续对话 + :param chat_info: 对话信息 :param message: 用户消息 + :param message_obj: 消息对象 :param is_save: 是否保存此轮对话记录 + :return: GPT 回复 """ - newconv = [] - conversations = chat_info.conversations.copy() - # 将conversation 内字典的所有 timestamp 字段删除 - for conv in conversations: - if "timestamp" in conv: - conv.pop("timestamp") - newconv.append({"role": "user", "content": message}) - + newconv = [{"role": "user", "content": message}] # 发送请求 headers = { "Authorization": CopilotGPT4.bearer_token, "Content-Type": "application/json", } - json = {"model": chat_info.chat_model, "messages": conversations + newconv} + json = { + "model": chat_info.model, + "messages": DEFAULT_CONVERSATION + chat_info.get_conversation() + newconv, + } r_json = post_request_json( url=CopilotGPT4.api, headers=headers, json=json, timeout=60 ) @@ -532,40 +440,73 @@ def _chat(chat_info: ChatInfo, message: str, is_save: bool = True) -> str: msg = r_json["choices"][0]["message"] msg_content = msg.get("content", "调用Copilot-GPT4-Server失败") # 将返回的 assistant 回复添加到对话记录中 - if is_save: + if is_save is True: newconv.append({"role": "assistant", "content": msg_content}) - chat_info.conversations.extend(newconv) - CopilotGPT4._update_chat(chat_info, newconv) + chat_info.extend_conversation(newconv) + with make_db_session() as session: + _chat_info = ( + session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() + ) + _chat_info.talk_time = datetime.now() + for chat_message in chat_info.gpt_chat_messages[-len(newconv) // 2 :]: + _chat_message = DbGptChatMessage.from_model(chat_message) + _chat_message.message_id = message_obj.id + _chat_info.gpt_chat_messages.append(_chat_message) + session.commit() return msg_content @staticmethod - def _save_chating_chat_topic(wx_id: str, model: str) -> None: - """生成正在进行的对话的主题""" - chat_info = CopilotGPT4.get_chating_chat_info(wx_id, model) - if chat_info is None or chat_info.has_topic: + def _save_chatting_chat_topic(person: Person, model: str) -> None: + """ + 生成正在进行的对话的主题 + """ + chat_info = CopilotGPT4.get_chatting_chat_info(person, model) + if chat_info is None or CopilotGPT4._has_topic(chat_info): return # 生成对话主题 - if not CopilotGPT4.is_chat_valid(chat_info): + if not CopilotGPT4._is_chat_valid(chat_info): logger.error("对话记录长度小于1") return topic = CopilotGPT4._generate_chat_topic(chat_info) - if topic == "": + if not topic: logger.error("生成对话主题失败") raise ValueError("生成对话主题失败") # 更新对话主题 - chat_info.chat_topic = topic - CopilotGPT4._update_chat(chat_info) + with make_db_session() as session: + chat_info = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() + chat_info.topic = topic + session.commit() @staticmethod - def _generate_chat_topic(chat_info: ChatInfo) -> str: - """生成对话主题,用于保存对话记录""" - assert CopilotGPT4.is_chat_valid(chat_info) + def _generate_chat_topic(chat_info: GptChatInfo) -> str: + """ + 生成对话主题,用于保存对话记录 + """ + assert CopilotGPT4._is_chat_valid(chat_info) # 通过一次对话生成对话主题,但这次对话不保存到对话记录中 prompt = "请用10个字以内总结一下这次对话的主题,不带任何标点符号" - topic = CopilotGPT4._chat(chat_info=chat_info, message=prompt, is_save=False) + topic = CopilotGPT4._chat( + chat_info=chat_info, message=prompt, message_obj=None, is_save=False + ) # 限制主题长度 if len(topic) > 21: topic = topic[:21] + "..." logger.info(f"生成对话主题:{topic}") return topic + + @staticmethod + def _has_topic(chat_info: GptChatInfo) -> bool: + """ + 判断对话是否有主题 + """ + return chat_info.topic != DEFAULT_TOPIC + + @staticmethod + def _is_chat_valid(chat_info: GptChatInfo) -> bool: + """ + 判断对话是否有效 + """ + if chat_info.gpt_chat_messages: + return True + return False diff --git a/wechatter/commands/_commands/douyin_hot.py b/wechatter/commands/_commands/douyin_hot.py index 301e031..9e07958 100644 --- a/wechatter/commands/_commands/douyin_hot.py +++ b/wechatter/commands/_commands/douyin_hot.py @@ -3,7 +3,7 @@ from loguru import logger from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request_json diff --git a/wechatter/commands/_commands/food_calories.py b/wechatter/commands/_commands/food_calories.py index 274e45e..5e82d75 100644 --- a/wechatter/commands/_commands/food_calories.py +++ b/wechatter/commands/_commands/food_calories.py @@ -7,7 +7,7 @@ from wechatter.commands.handlers import command from wechatter.exceptions import Bs4ParsingError -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request diff --git a/wechatter/commands/_commands/gasoline_price.py b/wechatter/commands/_commands/gasoline_price.py index 9bd4693..af654b6 100644 --- a/wechatter/commands/_commands/gasoline_price.py +++ b/wechatter/commands/_commands/gasoline_price.py @@ -5,7 +5,7 @@ import wechatter.utils.path_manager as pm from wechatter.commands.handlers import command from wechatter.exceptions import Bs4ParsingError -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request, load_json diff --git a/wechatter/commands/_commands/github_trending.py b/wechatter/commands/_commands/github_trending.py index f5a12d3..8360e17 100644 --- a/wechatter/commands/_commands/github_trending.py +++ b/wechatter/commands/_commands/github_trending.py @@ -6,7 +6,7 @@ from wechatter.commands.handlers import command from wechatter.exceptions import Bs4ParsingError -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request diff --git a/wechatter/commands/_commands/help.py b/wechatter/commands/_commands/help.py index 596e69f..1826312 100644 --- a/wechatter/commands/_commands/help.py +++ b/wechatter/commands/_commands/help.py @@ -2,7 +2,7 @@ import wechatter.config as config from wechatter.commands import commands from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils.text_to_image import text_to_image diff --git a/wechatter/commands/_commands/pai_post.py b/wechatter/commands/_commands/pai_post.py index 759b18d..36021e9 100644 --- a/wechatter/commands/_commands/pai_post.py +++ b/wechatter/commands/_commands/pai_post.py @@ -6,7 +6,7 @@ from wechatter.commands.handlers import command from wechatter.exceptions import Bs4ParsingError -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request diff --git a/wechatter/commands/_commands/people_daily.py b/wechatter/commands/_commands/people_daily.py index 1c348ee..a5581b3 100644 --- a/wechatter/commands/_commands/people_daily.py +++ b/wechatter/commands/_commands/people_daily.py @@ -1,7 +1,7 @@ from loguru import logger from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils.time import get_current_ymd diff --git a/wechatter/commands/_commands/qrcode.py b/wechatter/commands/_commands/qrcode.py index c2ffaa2..0493add 100644 --- a/wechatter/commands/_commands/qrcode.py +++ b/wechatter/commands/_commands/qrcode.py @@ -5,7 +5,7 @@ import wechatter.utils.path_manager as pm from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils.time import get_current_datetime diff --git a/wechatter/commands/_commands/today_in_history.py b/wechatter/commands/_commands/today_in_history.py index 7f7acea..6b83890 100644 --- a/wechatter/commands/_commands/today_in_history.py +++ b/wechatter/commands/_commands/today_in_history.py @@ -3,7 +3,7 @@ from loguru import logger from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request_json diff --git a/wechatter/commands/_commands/todo.py b/wechatter/commands/_commands/todo.py index 3d3d4be..b99c313 100644 --- a/wechatter/commands/_commands/todo.py +++ b/wechatter/commands/_commands/todo.py @@ -6,7 +6,7 @@ import wechatter.utils.path_manager as pm from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import load_json, save_json diff --git a/wechatter/commands/_commands/translate.py b/wechatter/commands/_commands/translate.py index 4eb3ad6..e68234e 100644 --- a/wechatter/commands/_commands/translate.py +++ b/wechatter/commands/_commands/translate.py @@ -6,7 +6,7 @@ from loguru import logger from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request, get_request_json diff --git a/wechatter/commands/_commands/trivia.py b/wechatter/commands/_commands/trivia.py index f493299..3d86ba4 100644 --- a/wechatter/commands/_commands/trivia.py +++ b/wechatter/commands/_commands/trivia.py @@ -7,7 +7,7 @@ from wechatter.commands.handlers import command from wechatter.exceptions import Bs4ParsingError -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request diff --git a/wechatter/commands/_commands/weather.py b/wechatter/commands/_commands/weather.py index 18d261b..d863bfe 100644 --- a/wechatter/commands/_commands/weather.py +++ b/wechatter/commands/_commands/weather.py @@ -8,7 +8,7 @@ import wechatter.utils.path_manager as pm from wechatter.commands.handlers import command from wechatter.exceptions import Bs4ParsingError -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request, load_json from wechatter.utils.time import get_current_hour, get_current_minute, get_current_ymdh diff --git a/wechatter/commands/_commands/weibo_hot.py b/wechatter/commands/_commands/weibo_hot.py index 8670ad0..ce5e318 100644 --- a/wechatter/commands/_commands/weibo_hot.py +++ b/wechatter/commands/_commands/weibo_hot.py @@ -3,7 +3,7 @@ from loguru import logger from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request_json diff --git a/wechatter/commands/_commands/zhihu_hot.py b/wechatter/commands/_commands/zhihu_hot.py index 0d7a1c9..90435af 100644 --- a/wechatter/commands/_commands/zhihu_hot.py +++ b/wechatter/commands/_commands/zhihu_hot.py @@ -3,7 +3,7 @@ from loguru import logger from wechatter.commands.handlers import command -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender from wechatter.utils import get_request_json diff --git a/wechatter/commands/handlers.py b/wechatter/commands/handlers.py index ff19882..bfd5226 100644 --- a/wechatter/commands/handlers.py +++ b/wechatter/commands/handlers.py @@ -1,5 +1,8 @@ +import inspect from typing import List +from loguru import logger + import wechatter.config as config commands = {} @@ -18,6 +21,35 @@ def command(command: str, keys: List[str], desc: str): """ def decorator(func): + sig = inspect.signature(func) + params = sig.parameters + if len(params) < 2: + error_message = f"缺少命令处理函数参数,命令处理函数至少需要 to 和 message 参数:{func.__name__}" + logger.error(error_message) + raise ValueError(error_message) + if "to" not in params: + error_message = ( + f"参数名错误,命令处理函数的第1个参数必须为 to:{func.__name__}" + ) + logger.error(error_message) + raise ValueError(error_message) + if "message" not in params: + error_message = ( + f"参数名错误,命令处理函数的第2个参数必须为 message:{func.__name__}" + ) + logger.error(error_message) + raise ValueError(error_message) + if len(params) == 3 and "message_obj" not in params: + error_message = ( + f"参数名错误,命令处理函数的第3个参数必须为 message_obj{func.__name__}" + ) + logger.error(error_message) + raise ValueError(error_message) + if len(params) > 3: + error_message = f"参数数量错误,命令处理函数参数数量不能超过3个(to, message, message_obj){func.__name__}" + logger.error(error_message) + raise ValueError(error_message) + commands[command] = {} # 自定义命令关键词 if config.custom_command_key_dict.get(command, None): @@ -26,6 +58,7 @@ def decorator(func): commands[command]["keys"] = keys commands[command]["desc"] = desc commands[command]["handler"] = func + commands[command]["param_count"] = len(params) return func diff --git a/wechatter/config/parsers/__init__.py b/wechatter/config/parsers/__init__.py index 4400005..830b881 100644 --- a/wechatter/config/parsers/__init__.py +++ b/wechatter/config/parsers/__init__.py @@ -1,4 +1,4 @@ -from .weather_cron import parse_weather_cron_rule_list from .gasoline_price_cron import parse_gasoline_price_cron_rule_list +from .weather_cron import parse_weather_cron_rule_list __all__ = ["parse_weather_cron_rule_list", "parse_gasoline_price_cron_rule_list"] diff --git a/wechatter/database/__init__.py b/wechatter/database/__init__.py new file mode 100644 index 0000000..57f6e28 --- /dev/null +++ b/wechatter/database/__init__.py @@ -0,0 +1,17 @@ +from .database import create_tables, make_db_session +from .tables import person_group_relation # noqa +from .tables.gpt_chat_info import GptChatInfo +from .tables.gpt_chat_message import GptChatMessage +from .tables.group import Group +from .tables.message import Message +from .tables.person import Person + +__all__ = [ + "make_db_session", + "create_tables", + "GptChatInfo", + "GptChatMessage", + "Message", + "Group", + "Person", +] diff --git a/wechatter/database/database.py b/wechatter/database/database.py new file mode 100644 index 0000000..23e201c --- /dev/null +++ b/wechatter/database/database.py @@ -0,0 +1,18 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from wechatter.database.tables import Base +from wechatter.utils.path_manager import get_abs_path + +DB_PATH = get_abs_path("data/wechatter.sqlite") + +engine = create_engine(f"sqlite:///{DB_PATH}") + + +# 创建数据库会话函数,在上下文管理器中使用 +make_db_session = sessionmaker(engine) + + +# 创建数据库表 +def create_tables(): + Base.metadata.create_all(engine, checkfirst=True) diff --git a/wechatter/database/tables/__init__.py b/wechatter/database/tables/__init__.py new file mode 100644 index 0000000..a14648b --- /dev/null +++ b/wechatter/database/tables/__init__.py @@ -0,0 +1,8 @@ +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass + + +__all__ = ["Base"] diff --git a/wechatter/database/tables/gpt_chat_info.py b/wechatter/database/tables/gpt_chat_info.py new file mode 100644 index 0000000..60a4c9d --- /dev/null +++ b/wechatter/database/tables/gpt_chat_info.py @@ -0,0 +1,83 @@ +from datetime import datetime +from typing import TYPE_CHECKING, List + +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from wechatter.database.tables import Base +from wechatter.database.tables.gpt_chat_message import GptChatMessage +from wechatter.models.gpt import ( + GptChatInfo as GptChatInfoModel, + GptChatMessage as GptChatMessageModel, +) + +if TYPE_CHECKING: + from wechatter.database.tables.person import Person + + +class GptChatInfo(Base): + """ + GPT聊天表 + """ + + __tablename__ = "gpt_chat_info" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + person_id: Mapped[str] = mapped_column(String, ForeignKey("person.id")) + topic: Mapped[str] + model: Mapped[str] + created_time: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=datetime.now() + ) + talk_time: Mapped[datetime] = mapped_column( + DateTime(timezone=True), onupdate=datetime.now() + ) + is_chatting: Mapped[bool] = mapped_column(Boolean, default=True) + + person: Mapped["Person"] = relationship("Person", back_populates="gpt_chat_infos") + gpt_chat_messages: Mapped[List["GptChatMessage"]] = relationship( + "GptChatMessage", back_populates="gpt_chat_info" + ) + + @classmethod + def from_model(cls, gpt_chat_info_model: GptChatInfoModel): + gpt_chat_messages = [] + for message in gpt_chat_info_model.gpt_chat_messages: + gpt_chat_messages.append(GptChatMessage.from_model(message)) + + return cls( + id=gpt_chat_info_model.id, + person_id=gpt_chat_info_model.person.id, + topic=gpt_chat_info_model.topic, + model=gpt_chat_info_model.model, + created_time=gpt_chat_info_model.created_time, + talk_time=gpt_chat_info_model.talk_time, + is_chatting=gpt_chat_info_model.is_chatting, + gpt_chat_messages=gpt_chat_messages, + ) + + def to_model(self) -> GptChatInfoModel: + gpt_chat_info = GptChatInfoModel( + id=self.id, + person=self.person.to_model(), + topic=self.topic, + model=self.model, + created_time=self.created_time, + talk_time=self.talk_time, + is_chatting=self.is_chatting, + ) + + gpt_chat_messages = [] + for message in self.gpt_chat_messages: + gpt_chat_messages.append( + GptChatMessageModel( + id=message.id, + message=message.message.to_model(), + gpt_chat_info=gpt_chat_info, + gpt_response=message.gpt_response, + # role=message.role.value, + ) + ) + gpt_chat_info.gpt_chat_messages = gpt_chat_messages + + return gpt_chat_info diff --git a/wechatter/database/tables/gpt_chat_message.py b/wechatter/database/tables/gpt_chat_message.py new file mode 100644 index 0000000..5062a90 --- /dev/null +++ b/wechatter/database/tables/gpt_chat_message.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from wechatter.database.tables import Base +from wechatter.models.gpt import GptChatMessage as GptChatMessageModel + +if TYPE_CHECKING: + from wechatter.database.tables.gpt_chat_info import GptChatInfo + from wechatter.database.tables.message import Message + + +class GptChatMessage(Base): + """ + GPT对话消息表 + """ + + __tablename__ = "gpt_chat_message" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + message_id: Mapped[int] = mapped_column( + Integer, ForeignKey("message.id"), unique=True + ) + gpt_chat_id: Mapped[int] = mapped_column(Integer, ForeignKey("gpt_chat_info.id")) + gpt_response: Mapped[str] + + message: Mapped["Message"] = relationship( + "Message", back_populates="gpt_chat_message" + ) + gpt_chat_info: Mapped["GptChatInfo"] = relationship( + "GptChatInfo", back_populates="gpt_chat_messages" + ) + + @classmethod + def from_model(cls, gpt_chat_message_model: GptChatMessageModel): + return cls( + message_id=gpt_chat_message_model.message.id, + gpt_chat_id=gpt_chat_message_model.gpt_chat_info.id, + gpt_response=gpt_chat_message_model.gpt_response, + ) + + def to_model(self) -> GptChatMessageModel: + return GptChatMessageModel( + id=self.id, + message=self.message.to_model(), + gpt_chat_info=self.gpt_chat_info.to_model(), + gp_response=self.gpt_response, + ) diff --git a/wechatter/database/tables/group.py b/wechatter/database/tables/group.py new file mode 100644 index 0000000..f7b2c14 --- /dev/null +++ b/wechatter/database/tables/group.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING, List, Union + +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from wechatter.database.tables import Base +from wechatter.models.wechat import Group as GroupModel + +if TYPE_CHECKING: + from wechatter.database.tables.message import Message + from wechatter.database.tables.person import Person + + +class Group(Base): + """ + 微信群表 + """ + + __tablename__ = "group" + + id: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] + alias: Mapped[Union[str, None]] = mapped_column(String, nullable=True) + + members: Mapped[List["Person"]] = relationship( + "Person", + secondary="person_group_relation", + back_populates="groups", + ) + messages: Mapped[List["Message"]] = relationship("Message", back_populates="group") + + @classmethod + def from_model(cls, group_model: GroupModel): + return cls( + id=group_model.id, + name=group_model.name, + ) + + def to_model(self) -> GroupModel: + member_list = [] + for member in self.members: + member_list.append(member.to_model()) + return GroupModel( + id=self.id, + name=self.name, + member_list=member_list, + ) + + def update(self, group_model: GroupModel): + self.name = group_model.name + member_list = [] + for member in self.members: + member_list.append(Person.from_member_model(member)) + self.members = member_list diff --git a/wechatter/database/tables/message.py b/wechatter/database/tables/message.py new file mode 100644 index 0000000..367fc8d --- /dev/null +++ b/wechatter/database/tables/message.py @@ -0,0 +1,68 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Union + +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, func +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from wechatter.database.tables import Base +from wechatter.database.tables.gpt_chat_message import GptChatMessage +from wechatter.models.wechat import Message as MessageModel, MessageType + +if TYPE_CHECKING: + from wechatter.database.tables.group import Group + from wechatter.database.tables.person import Person + + +class Message(Base): + """ + 消息表 + """ + + __tablename__ = "message" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + person_id: Mapped[str] = mapped_column(String, ForeignKey("person.id")) + group_id: Mapped[Union[str, None]] = mapped_column( + String, ForeignKey("group.id"), nullable=True + ) + type: Mapped[MessageType] + content: Mapped[str] + created_time: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + is_mentioned: Mapped[bool] = mapped_column(Boolean, default=False) + is_quoted: Mapped[bool] = mapped_column(Boolean, default=False) + + person: Mapped["Person"] = relationship("Person", back_populates="messages") + group: Mapped[Union["Group", None]] = relationship( + "Group", back_populates="messages" + ) + gpt_chat_message: Mapped[Union["GptChatMessage", None]] = relationship( + "GptChatMessage", back_populates="message", uselist=False + ) + + @classmethod + def from_model(cls, message_model: MessageModel): + group_id = None + if message_model.is_group: + group_id = message_model.group.id + + return cls( + id=message_model.id, + person_id=message_model.person.id, + group_id=group_id, + type=message_model.type.value, + content=message_model.content, + is_mentioned=message_model.is_mentioned, + is_quoted=message_model.is_quoted, + ) + + def to_model(self) -> MessageModel: + return MessageModel( + id=self.id, + type=self.type, + person=self.person.to_model(), + group=self.group.to_model() if self.group else None, + content=self.content, + is_mentioned=self.is_mentioned, + ) diff --git a/wechatter/database/tables/person.py b/wechatter/database/tables/person.py new file mode 100644 index 0000000..71d9b8d --- /dev/null +++ b/wechatter/database/tables/person.py @@ -0,0 +1,83 @@ +from typing import TYPE_CHECKING, List, Union + +from sqlalchemy import Boolean, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from wechatter.database.tables import Base +from wechatter.models.wechat import Gender, GroupMember, Person as PersonModel + +if TYPE_CHECKING: + from wechatter.database.tables.gpt_chat_info import GptChatInfo + from wechatter.database.tables.group import Group + from wechatter.database.tables.message import Message + + +class Person(Base): + """ + 微信用户表 + """ + + __tablename__ = "person" + + id: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] + alias: Mapped[Union[str, None]] = mapped_column(String, nullable=True) + gender: Mapped[Union[Gender, None]] = mapped_column(String, nullable=True) + province: Mapped[Union[str, None]] = mapped_column(String, nullable=True) + city: Mapped[Union[str, None]] = mapped_column(String, nullable=True) + # phone: Mapped[Union[str, None]] = mapped_column(String, nullable=True) + is_star: Mapped[bool] = mapped_column(Boolean, default=False) + is_friend: Mapped[bool] = mapped_column(Boolean, default=False) + + groups: Mapped[List["Group"]] = relationship( + "Group", + secondary="person_group_relation", + back_populates="members", + ) + messages: Mapped[List["Message"]] = relationship("Message", back_populates="person") + gpt_chat_infos: Mapped[List["GptChatInfo"]] = relationship( + "GptChatInfo", back_populates="person" + ) + + @classmethod + def from_model(cls, person_model: PersonModel): + return cls( + id=person_model.id, + name=person_model.name, + alias=person_model.alias, + gender=person_model.gender.value, + province=person_model.province, + city=person_model.city, + is_star=person_model.is_star, + is_friend=person_model.is_friend, + ) + + @classmethod + def from_member_model(cls, member_model: "GroupMember"): + return cls( + id=member_model.id, + name=member_model.name, + alias=member_model.alias, + ) + + def to_model(self) -> PersonModel: + return PersonModel( + id=self.id, + name=self.name, + alias=self.alias, + gender=self.gender, + signature="", + province=self.province, + city=self.city, + is_star=self.is_star, + is_friend=self.is_friend, + ) + + def update(self, person_model: PersonModel): + self.name = person_model.name + self.alias = person_model.alias + self.gender = person_model.gender.value + self.province = person_model.province + self.city = person_model.city + self.is_star = person_model.is_star + self.is_friend = person_model.is_friend diff --git a/wechatter/database/tables/person_group_relation.py b/wechatter/database/tables/person_group_relation.py new file mode 100644 index 0000000..c33afae --- /dev/null +++ b/wechatter/database/tables/person_group_relation.py @@ -0,0 +1,19 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from wechatter.database.tables import Base + + +class PersonGroupRelation(Base): + """ + 用户和群组的关系表 + """ + + __tablename__ = "person_group_relation" + + person_id: Mapped[str] = mapped_column( + String, ForeignKey("person.id"), primary_key=True + ) + group_id: Mapped[str] = mapped_column( + String, ForeignKey("group.id"), primary_key=True + ) diff --git a/wechatter/message/message_parser.py b/wechatter/message/message_parser.py index 3761ba6..963ccf6 100644 --- a/wechatter/message/message_parser.py +++ b/wechatter/message/message_parser.py @@ -5,17 +5,21 @@ import wechatter.config as config from wechatter.bot.bot_info import BotInfo -from wechatter.models.message import Message, SendTo +from wechatter.models.wechat import Message, SendTo class MessageHandler: - """消息处理器,用于处理用户发来的消息""" + """ + 消息处理器,用于处理用户发来的消息 + """ def __init__(self, commands: dict): self.__commands = commands def handle_message(self, message: Message) -> None: - """处理消息""" + """ + 处理消息 + """ # 解析命令 content = message.content # 消息内容 # 消息内容格式: / @@ -37,24 +41,37 @@ def handle_message(self, message: Message) -> None: logger.debug("该消息为群消息,但未@机器人,不处理") return - to = SendTo.from_message_source(message.source) + to = SendTo(person=message.person, group=message.group) # 是命令消息 # 开始处理命令 cmd_handler = cmd_dict["handler"] if cmd_handler is not None: - cmd_handler(to, cmd_dict["arg"]) + if cmd_dict["param_count"] == 2: + cmd_handler( + to=to, + message=cmd_dict["arg"], + ) + elif cmd_dict["param_count"] == 3: + cmd_handler( + to=to, + message=cmd_dict["arg"], + message_obj=message, + ) else: logger.error("该命令未实现") return def parse_command(self, content: str, is_mentioned: bool, is_group: bool) -> dict: - """解析命令""" + """ + 解析命令 + """ cmd_dict = { "command": "None", "desc": "", "arg": "", "handler": None, + "param_count": 0, } # 不带命令前缀和@前缀的消息内容 if is_mentioned and is_group: @@ -71,6 +88,7 @@ def parse_command(self, content: str, is_mentioned: bool, is_group: bool) -> dic cmd_dict["command"] = command cmd_dict["desc"] = info["desc"] cmd_dict["handler"] = info["handler"] + cmd_dict["param_count"] = info["param_count"] if len(cont_list) == 2: cmd_dict["arg"] = cont_list[1] # 消息内容 return cmd_dict diff --git a/wechatter/message_forwarder/message_forwarder.py b/wechatter/message_forwarder/message_forwarder.py index 6ef8cd6..f4e686f 100644 --- a/wechatter/message_forwarder/message_forwarder.py +++ b/wechatter/message_forwarder/message_forwarder.py @@ -2,7 +2,7 @@ from loguru import logger -from wechatter.models.message import Message +from wechatter.models.wechat import Message from wechatter.sender import sender @@ -15,15 +15,8 @@ def __init__(self, rule_list: List): def forward_message(self, message: Message): """消息转发""" - # 判断消息来源 - from_name = "" - if message.is_group: - from_name = message.source.g_info.name - else: - from_name = message.source.p_info.name - # TODO: 转发文件 - + from_name = message.sender_name # 判断消息是否符合转发规则 for rule in self.rule_list: # 判断消息来源是否符合转发规则 @@ -31,7 +24,7 @@ def forward_message(self, message: Message): # 构造转发消息 msg = self.__construct_forwarding_message(message) logger.info( - f"转发消息:{from_name} -> {rule['to_persons']}\n" + f"转发消息:{from_name} -> {rule['to_persons']};" f"转发消息:{from_name} -> {rule['to_groups']}" ) sender.mass_send_msg(rule["to_persons"], msg) @@ -42,13 +35,13 @@ def __construct_forwarding_message(self, message: Message) -> str: content = message.content if message.is_group: content = ( - f"⤴️ {message.source.p_info.name}在{message.source.g_info.name}中说:\n" + f"⤴️ {message.person.name}在{message.group.name}中说:\n" f"-------------------------\n" f"{content}" ) else: content = ( - f"⤴️ {message.source.p_info.name}说:\n" + f"⤴️ {message.person.name}说:\n" f"-------------------------\n" f"{content}" ) diff --git a/wechatter/models/__init__.py b/wechatter/models/__init__.py index e69de29..4d8af86 100644 --- a/wechatter/models/__init__.py +++ b/wechatter/models/__init__.py @@ -0,0 +1,6 @@ +from .gpt import GptChatInfo, GptChatMessage # noqa: F401 +from .wechat import Message, Person # noqa: F401 + +GptChatInfo.model_rebuild() + +__all__ = [] diff --git a/wechatter/models/github/pr_webhook.py b/wechatter/models/github/pr_webhook.py index 390d4b5..d83cb17 100644 --- a/wechatter/models/github/pr_webhook.py +++ b/wechatter/models/github/pr_webhook.py @@ -34,11 +34,11 @@ class PullRequest(BaseModel): state: str title: str user: User - body: Optional[str] + body: Optional[str] = None base: PrBranch head: PrBranch merged: bool - merged_by: Optional[User] + merged_by: Optional[User] = None class GithubPrWebhook(BaseModel): diff --git a/wechatter/models/gpt/__init__.py b/wechatter/models/gpt/__init__.py new file mode 100644 index 0000000..2c96029 --- /dev/null +++ b/wechatter/models/gpt/__init__.py @@ -0,0 +1,4 @@ +from .gpt_chat_info import GptChatInfo +from .gpt_chat_message import GptChatMessage + +__all__ = ["GptChatMessage", "GptChatInfo"] diff --git a/wechatter/models/gpt/gpt_chat_info.py b/wechatter/models/gpt/gpt_chat_info.py new file mode 100644 index 0000000..b661b94 --- /dev/null +++ b/wechatter/models/gpt/gpt_chat_info.py @@ -0,0 +1,44 @@ +from datetime import datetime +from typing import TYPE_CHECKING, List, Optional + +from pydantic import BaseModel + +from wechatter.models.gpt.gpt_chat_message import GptChatMessage +from wechatter.models.wechat import Message + +if TYPE_CHECKING: + from wechatter.models.wechat import Person + + +class GptChatInfo(BaseModel): + id: Optional[int] = None + person: "Person" + topic: str + model: str + created_time: datetime = datetime.now() + talk_time: datetime = datetime.now() + is_chatting: bool = True + gpt_chat_messages: List[GptChatMessage] = [] + + def get_conversation(self) -> List: + conversation = [] + for message in self.gpt_chat_messages: + conversation.extend(message.to_turn()) + return conversation + + def extend_conversation(self, conversation: List): + conv = [] + for i in range(0, len(conversation) - 1, 2): + conv.append( + GptChatMessage( + message=Message( + type="text", + person=self.person, + content=conversation[i]["content"], + ), + gpt_chat_info=self, + gpt_response=conversation[i + 1]["content"], + ) + ) + self.gpt_chat_messages.extend(conv) + return self diff --git a/wechatter/models/gpt/gpt_chat_message.py b/wechatter/models/gpt/gpt_chat_message.py new file mode 100644 index 0000000..b393b7b --- /dev/null +++ b/wechatter/models/gpt/gpt_chat_message.py @@ -0,0 +1,33 @@ +import enum +from typing import TYPE_CHECKING, Optional + +from pydantic import BaseModel + +if TYPE_CHECKING: + from wechatter.models.gpt.gpt_chat_info import GptChatInfo + from wechatter.models.wechat import Message + + +class GptChatRole(enum.Enum): + system = "system" + user = "user" + assistant = "assistant" + + +class GptChatMessage(BaseModel): + id: Optional[int] = None + message: "Message" + gpt_chat_info: "GptChatInfo" + gpt_response: str + + def to_turn(self): + return [ + { + "role": GptChatRole.user.value, + "content": self.message.content, + }, + { + "role": GptChatRole.assistant.value, + "content": self.gpt_response, + }, + ] diff --git a/wechatter/models/message/__init__.py b/wechatter/models/message/__init__.py deleted file mode 100644 index 8c1717f..0000000 --- a/wechatter/models/message/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .message import Message, MessageSource -from .send_message import SendTo - -__all__ = [ - "Message", - "MessageSource", - "SendTo", -] diff --git a/wechatter/models/message/group_info.py b/wechatter/models/message/group_info.py deleted file mode 100644 index 3db01d4..0000000 --- a/wechatter/models/message/group_info.py +++ /dev/null @@ -1,86 +0,0 @@ -from dataclasses import dataclass, field -from typing import List - - -@dataclass -class GroupMemberInfo: - """群成员类""" - - id: str - name: str - alias: str = "" - - -@dataclass -class GroupInfo: - """群消息类""" - - id: str - name: str - admin_id_list: List[str] = field(default_factory=list) - member_list: List[GroupMemberInfo] = field(default_factory=list) - - -# class GroupMemberInfo: -# """群成员类""" - -# def __init__( -# self, -# id: str, -# name: str, -# alias: str = "", -# ): -# self.id = id -# self.name = name -# self.alias = alias - -# def __str__(self) -> str: -# return f"微信ID:{self.id}\n昵称:{self.name}\n备注:{self.alias}" - - -# class GroupInfo: -# """群消息类""" - -# def __init__( -# self, -# id: str, -# name: str, -# admin_id_list: List[str] = [], -# member_list: List[dict] = [], -# ): -# self.id = id -# self.name = name -# self.admin_id_list = admin_id_list -# self.member_list = member_list - -# @property -# def admin_id_list(self) -> List[str]: -# """获取管理员ID列表""" -# return self.__admin_id_list - -# @admin_id_list.setter -# def admin_id_list(self, admin_id_list: List[str]): -# self.__admin_id_list = admin_id_list - -# @property -# def member_list(self) -> List[GroupMemberInfo]: -# """获取群成员列表""" -# return self.__member_list - -# @member_list.setter -# def member_list(self, member_list: List[dict]): -# self.__member_list = [] -# for m in member_list: -# self.__member_list.append( -# GroupMemberInfo( -# id=m["id"], -# name=m["name"], -# alias=m["alias"], -# ) -# ) - -# def __str__(self) -> str: -# # 群成员数量过多,不打印 -# # member_list_str = "[\n{}\n]".format('\n'.join([str(m) for m in self.member_list])) -# # return f"群ID: {self.id}\n群名:{self.name}\n管理员:{self.admin_id_list}\n成员:{member_list_str}" -# return f"群ID: {self.id}\n群名:{self.name}\n管理员:{self.admin_id_list}\n成员:{str(self.member_list)}" diff --git a/wechatter/models/message/message.py b/wechatter/models/message/message.py deleted file mode 100644 index e20e6cb..0000000 --- a/wechatter/models/message/message.py +++ /dev/null @@ -1,209 +0,0 @@ -# 消息类 -import json -import re -from enum import Enum -from typing import Union - -from loguru import logger - -import wechatter.config as config -from wechatter.models.message.group_info import GroupInfo -from wechatter.models.message.person_info import PersonInfo - - -class MessageType(Enum): - """消息类型枚举""" - - TEXT = 0 - FILE = 1 - LINK = 2 - # TODO: 图片识别 - # IMAGE = 3 - - -class MessageSenderType(Enum): - """消息来源枚举""" - - PERSON = 0 - GROUP = 1 - # TODO: 公众号文章 - # ARTICLE = 3 - - -class MessageSource: - """消息来源类""" - - def __init__( - self, - p_info: PersonInfo, - g_info: Union[GroupInfo, None] = None, - ): - self.p_info = p_info - self.g_info = g_info - - def __str__(self) -> str: - result = "" - if self.g_info is not None: - result += str(self.g_info) - result += str(self.p_info) - return result - - -class Message: - """消息类 - :property content: 消息内容 - :property source: 消息来源 - :property is_mentioned: 是否@机器人 - :property is_quote: 是否引用机器人消息 - :property is_group: 是否是群消息 - """ - - def __init__( - self, - type: str, - content: str, - source: str, - is_mentioned: str = "0", - command: dict = {}, - ): - self.type = type - self.content = content - self.source: MessageSource = source - self.is_mentioned = is_mentioned - self.is_group = bool(self.source.g_info) - self.is_quote = content - self.command = command - - # 获取消息类型 - @property - def type(self) -> MessageType: - return self.__type - - @type.setter - def type(self, type: str): - if type == "text": - self.__type = MessageType.TEXT - elif type == "file": - self.__type = MessageType.FILE - elif type == "urlLink": - self.__type = MessageType.LINK - else: - raise ValueError("消息类型错误") - - @property - def content(self) -> str: - """获取消息内容""" - return self.__content - - @content.setter - def content(self, content: str): - # 对于 iPad、手机端的微信,@名称后面会跟着一个未解码的空格的Unicode编码:"@Cassius\u2005/help" - self.__content = content.replace("\u2005", " ", 1) - - @property - def msg(self) -> str: - """获取消息内容""" - return self.__msg - - @property - def source(self) -> MessageSource: - """获取消息来源""" - return self.__source - - @source.setter - def source(self, source_json_str: str): - if source_json_str == "": - self.__source = MessageSource() - return - # 解析json - source_json = dict() - try: - source_json = json.loads(source_json_str) - except json.JSONDecodeError as e: - logger.error("消息来源解析失败") - raise e - - # from为发送者信息,无论是个人消息还是群消息,都有from - payload = source_json.get("from").get("payload", {}) - if payload == {}: - self.__source = MessageSource() - return - id = payload.get("id", "") - name = payload.get("name", "") - alias = payload.get("alias", "") - gender = int(payload.get("gender", -1)) - signature = payload.get("signature", "") - province = payload.get("province", "") - city = payload.get("city", "") - phone_list = payload.get("phone", []) - is_star = payload.get("star", "") - message_source = MessageSource( - p_info=PersonInfo( - id=id, - name=name, - alias=alias, - gender=gender, - signature=signature, - province=province, - city=city, - phone_list=phone_list, - is_star=is_star, - ) - ) - - # room为群信息,只有群消息才有room - if source_json["room"] != "": - g_data = source_json["room"] - id = g_data["id"] - payload = g_data.get("payload", {}) - name = payload.get("topic", "") - admin_id_list = payload.get("adminIdList", []) - member_list = payload.get("memberList", []) - message_source.g_info = GroupInfo( - id=id, - name=name, - admin_id_list=admin_id_list, - member_list=member_list, - ) - self.__source = message_source - - @property - def is_mentioned(self) -> bool: - """是否@机器人""" - return self.__is_mentioned - - @is_mentioned.setter - def is_mentioned(self, is_mentioned: str): - if is_mentioned == "1": - self.__is_mentioned = True - else: - self.__is_mentioned = False - - @property - def is_group(self) -> bool: - """是否是群消息""" - return self.__is_group - - @is_group.setter - def is_group(self, is_group: bool): - self.__is_group = is_group - - @property - def is_quote(self) -> bool: - """是否引用机器人消息""" - return self.__is_quote - - @is_quote.setter - def is_quote(self, content: str): - self.__is_quote = False - # 判断是否为引用消息 - quote_pattern = ( - r"(?s)「(.*?)」\n- - - - - - - - - - - - - - -" # 引用消息的正则 - ) - match_result = re.match(quote_pattern, content) - # 判断是否为引用机器人消息 - if bool(match_result) and content.startswith(f"「{config.bot_name}"): - self.__is_quote = True - - def __str__(self) -> str: - return f"消息内容:{self.content}\n消息来源:\n{self.source}\n是否@:{self.is_mentioned}\n是否引用:{self.is_quote}" diff --git a/wechatter/models/message/person_info.py b/wechatter/models/message/person_info.py deleted file mode 100644 index 4ad5025..0000000 --- a/wechatter/models/message/person_info.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass, field -from typing import List - - -@dataclass -class PersonInfo: - """个人消息类""" - - id: str - name: str - alias: str = "" - gender: int = -1 - signature: str = "" - province: str = "" - city: str = "" - phone_list: List[str] = field(default_factory=list) - is_star: bool = False diff --git a/wechatter/models/message/send_message.py b/wechatter/models/message/send_message.py deleted file mode 100644 index 0204b41..0000000 --- a/wechatter/models/message/send_message.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass - -from wechatter.models.message.message import MessageSource - - -@dataclass -class SendTo: - """发送对象类""" - - p_id: str - p_name: str - # p_alias: str - g_id: str - g_name: str - - @classmethod - def from_message_source(cls, source: MessageSource): - p_id = "" - p_name = "" - # p_alias = "" - g_id = "" - g_name = "" - if source.p_info is not None: - p_id = source.p_info.id - p_name = source.p_info.name - if source.g_info is not None: - g_id = source.g_info.id - g_name = source.g_info.name - return cls(p_id, p_name, g_id, g_name) diff --git a/wechatter/models/wechat/__init__.py b/wechatter/models/wechat/__init__.py new file mode 100644 index 0000000..571e673 --- /dev/null +++ b/wechatter/models/wechat/__init__.py @@ -0,0 +1,14 @@ +from .group import Group, GroupMember +from .message import Message, MessageType +from .person import Gender, Person +from .send_to import SendTo + +__all__ = [ + "Message", + "MessageType", + "SendTo", + "Group", + "GroupMember", + "Person", + "Gender", +] diff --git a/wechatter/models/wechat/group.py b/wechatter/models/wechat/group.py new file mode 100644 index 0000000..fd1b8a1 --- /dev/null +++ b/wechatter/models/wechat/group.py @@ -0,0 +1,25 @@ +from typing import List, Optional + +from pydantic import BaseModel + + +class GroupMember(BaseModel): + """ + 群成员类 + """ + + id: str + name: str + alias: str + + +class Group(BaseModel): + """ + 群消息类 + """ + + id: str + name: str + # alias: str 目前上游不支持 + admin_id_list: Optional[List[str]] = None + member_list: List[GroupMember] diff --git a/wechatter/models/wechat/message.py b/wechatter/models/wechat/message.py new file mode 100644 index 0000000..d9e771c --- /dev/null +++ b/wechatter/models/wechat/message.py @@ -0,0 +1,153 @@ +# 消息类 +import enum +import json +import re +from functools import cached_property +from typing import Optional + +from loguru import logger +from pydantic import BaseModel, computed_field + +import wechatter.config as config +from wechatter.models.wechat.group import Group +from wechatter.models.wechat.person import Person + + +class MessageType(enum.Enum): + """ + 消息类型枚举类 + """ + + text = "text" + file = "file" + urlLink = "urlLink" + friendship = "friendship" + + +class MessageSenderType(enum.Enum): + """ + 消息来源枚举 + """ + + PERSON = 0 + GROUP = 1 + # TODO: 公众号文章 + # ARTICLE = 2 + + +class Message(BaseModel): + """ + 微信消息类(消息接收) + """ + + type: MessageType + person: Person + group: Optional[Group] = None + content: str + is_mentioned: bool = False + id: Optional[int] = None + + @classmethod + def from_api_msg( + cls, + type: MessageType, + content: str, + source: str, + is_mentioned: str, + ): + """ + 从API接口创建消息对象 + """ + try: + source_json = json.loads(source) + except json.JSONDecodeError as e: + logger.error("消息来源解析失败") + raise e + + # from为发送者信息,无论是个人消息还是群消息,都有from + payload = source_json.get("from").get("payload", {}) + gender = int(payload.get("gender", -1)) + g = "unknown" + if gender == 1: + g = "male" + elif gender == 0: + g = "female" + _person = Person( + id=payload.get("id", ""), + name=payload.get("name", ""), + alias=payload.get("alias", ""), + gender=g, + signature=payload.get("signature", ""), + province=payload.get("province", ""), + city=payload.get("city", ""), + # phone_list=payload.get("phone", []), + is_star=payload.get("star", ""), + is_friend=payload.get("friend", ""), + ) + + _group = None + # room为群信息,只有群消息才有room + if source_json["room"] != "": + g_data = source_json["room"] + payload = g_data.get("payload", {}) + _group = Group( + id=g_data.get("id", ""), + name=payload.get("topic", ""), + admin_id_list=payload.get("adminIdList", []), + member_list=payload.get("memberList", []), + ) + _content = content.replace("\u2005", " ", 1) + _is_mentioned = False + if is_mentioned == "1": + _is_mentioned = True + return cls( + type=type, + person=_person, + group=_group, + content=_content, + is_mentioned=is_mentioned, + ) + + @computed_field + @property + def is_group(self) -> bool: + """ + 是否是群消息 + """ + return self.group is not None + + @computed_field + @cached_property + def is_quoted(self) -> bool: + """ + 是否引用机器人消息 + """ + # 引用消息的正则 + quote_pattern = r"(?s)「(.*?)」\n- - - - - - - - - - - - - - -" + match_result = re.match(quote_pattern, self.content) + # 判断是否为引用机器人消息 + if match_result and self.content.startswith(f"「{config.bot_name}"): + return True + return False + + # TODO: 判断所有的引用消息,不仅仅是机器人消息 + # 待解决:在群中如果有人设置了自己的群中名称,那么引用内容的名字会变化,导致无法匹配到用户 + + @computed_field + @property + def sender_name(self) -> str: + """ + 返回消息发送对象名,如果是群则返回群名,如果不是则返回人名 + """ + return self.group.name if self.is_group else self.person.name + + def __str__(self) -> str: + source = self.person + if self.is_group: + source = self.group + return ( + f"消息内容:{self.content}\n" + f"消息来源:{source}\n" + f"是否@:{self.is_mentioned}\n" + f"是否引用:{self.is_quoted}" + ) diff --git a/wechatter/models/wechat/person.py b/wechatter/models/wechat/person.py new file mode 100644 index 0000000..f7ebdab --- /dev/null +++ b/wechatter/models/wechat/person.py @@ -0,0 +1,30 @@ +import enum + +from pydantic import BaseModel + + +class Gender(enum.Enum): + """ + 性别类 + """ + + male = "male" + female = "female" + unknown = "unknown" + + +class Person(BaseModel): + """ + 个人消息类 + """ + + id: str + name: str + alias: str + gender: Gender + signature: str + province: str + city: str + # phone_list: List[str] + is_star: bool + is_friend: bool diff --git a/wechatter/models/wechat/send_to.py b/wechatter/models/wechat/send_to.py new file mode 100644 index 0000000..f5cc151 --- /dev/null +++ b/wechatter/models/wechat/send_to.py @@ -0,0 +1,49 @@ +from typing import Optional + +from loguru import logger +from pydantic import BaseModel, computed_field + +from wechatter.models.wechat.group import Group +from wechatter.models.wechat.person import Person + + +class SendTo(BaseModel): + """ + 发送对象类 + """ + + person: Person + group: Optional[Group] = None + + @computed_field + @property + def p_id(self) -> str: + return self.person.id + + @computed_field + @property + def p_name(self) -> str: + return self.person.name + + @computed_field + @property + def p_alias(self) -> str: + return self.person.alias + + @computed_field + @property + def g_id(self) -> Optional[str]: + try: + return self.group.id + except AttributeError: + logger.warning("此发送对象不是群聊") + return None + + @computed_field + @property + def g_name(self) -> Optional[str]: + try: + return self.group.name + except AttributeError: + logger.warning("此发送对象不是群聊") + return None diff --git a/wechatter/sender/notifier.py b/wechatter/sender/notifier.py index 37aab05..d9c4e56 100644 --- a/wechatter/sender/notifier.py +++ b/wechatter/sender/notifier.py @@ -1,5 +1,5 @@ # 消息通知器 -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo from wechatter.sender import sender diff --git a/wechatter/sender/quotation.py b/wechatter/sender/quotation.py new file mode 100644 index 0000000..6e9be8b --- /dev/null +++ b/wechatter/sender/quotation.py @@ -0,0 +1,20 @@ +QUOTABLE_FORMAT = "(可引用:%s)\n" + + +# 将消息可引用化 +def make_quotable(message: str) -> str: + """ + 将消息可引用化 + :param message: 消息内容 + :return: 可引用的消息内容 + """ + # 获取可引用消息的ID(可引用标识符) + quotable_id = _get_random_quotable_id() + return QUOTABLE_FORMAT % quotable_id + message + + +def _get_random_quotable_id() -> str: + """ + 获取可引用消息的ID + :return: 可引用消息的ID + """ diff --git a/wechatter/sender/sender.py b/wechatter/sender/sender.py index 9539a01..dbcde6c 100644 --- a/wechatter/sender/sender.py +++ b/wechatter/sender/sender.py @@ -1,6 +1,6 @@ import json from functools import singledispatch -from typing import List +from typing import List, Union import requests import tenacity @@ -8,7 +8,7 @@ import wechatter.config as config import wechatter.utils.http_request as http_request -from wechatter.models.message import SendTo +from wechatter.models.wechat import SendTo # 对retry装饰器重新包装,增加日志输出 @@ -42,12 +42,13 @@ def _post_request( ) +# TODO: 改成装饰器 def _log(response: requests.Response) -> bool: """检查发送状态""" r_json = response.json() # https://github.com/danni-cool/wechatbot-webhook?tab=readme-ov-file#%E8%BF%94%E5%9B%9E%E5%80%BC-response-%E7%BB%93%E6%9E%84 if r_json["message"].startswith("Message"): - logger.info("发送消息成功") + pass elif r_json["message"].startswith("Some"): logger.error("发送消息失败,参数校验不通过") elif r_json["message"].startswith("All"): @@ -103,17 +104,27 @@ def wrapper(n, m, *args, **kwargs): @singledispatch -def send_msg(): +def send_msg( + to: Union[str, SendTo], + message: str, + is_group: bool = False, + type: str = "text", + quotable: bool = False, +): """ 发送消息 当传入的第一个参数是字符串时,is_group 默认为 False。 当传入的第一个参数是 SendTo 对象时,is_group 默认为 True。 + 当 quotable 为 Ture 时,该消息为可引用消息。表示该消息被 + 引用回复后,会触发进一步的消息互动。 + :param to: 接收对象的名字或SendTo对象 :param message: 消息内容 :param is_group: 是否为群组(默认值根据 to 的类型而定) :param type: 消息类型,可选 text、fileUrl(默认值为 text) + :param quotable: 是否可引用(默认值为 False) """ pass @@ -121,7 +132,11 @@ def send_msg(): @send_msg.register(str) @_validate def _send_msg1( - name: str, message: str, is_group: bool = False, type: str = "text" + name: str, + message: str, + is_group: bool = False, + type: str = "text", + quotable: bool = False, ) -> None: """ 发送消息 @@ -129,7 +144,10 @@ def _send_msg1( :param message: 消息内容 :param is_group: 是否为群组(默认为个人,False) :param type: 消息类型(text、fileUrl) + :param quotable: 是否可引用(默认为不可引用,False) """ + # if quotable: + # message = f"@{name} {message}" data = { "to": name, "isRoom": is_group, @@ -150,9 +168,9 @@ def _send_msg2(to: SendTo, message: str, is_group: bool = True, type: str = "tex if not is_group: return _send_msg1(to.p_name, message, is_group=False, type=type) - if to.g_name != "": + if to.group: return _send_msg1(to.g_name, message, is_group=True, type=type) - elif to.p_name != "": + elif to.person: return _send_msg1(to.p_name, message, is_group=False, type=type) else: logger.error("发送消息失败,接收者为空") diff --git a/wechatter/sqlite/sqls/ct_chat_conversations.sql b/wechatter/sqlite/sqls/ct_chat_conversations.sql deleted file mode 100644 index bf614d0..0000000 --- a/wechatter/sqlite/sqls/ct_chat_conversations.sql +++ /dev/null @@ -1,20 +0,0 @@ --- 创建 chat_conversations 表 -PRAGMA foreign_keys = false; - -CREATE TABLE IF NOT EXISTS "chat_conversations" -( - conversation_id INTEGER not null - constraint chat_conversations_pk - primary key autoincrement, - chat_id INTEGER - constraint chat_conversations_copilot_chats_chat_id_fk - references copilot_chats, - conversation_timestamp TIMESTAMP, - conversation_role TEXT, - conversation_content TEXT, - constraint conversation_role_check - check ("chat_conversations".conversation_role IN ('system', 'user', 'assistant')) -); - -PRAGMA foreign_keys = true; - diff --git a/wechatter/sqlite/sqls/ct_copilot_chats.sql b/wechatter/sqlite/sqls/ct_copilot_chats.sql deleted file mode 100644 index 672bed9..0000000 --- a/wechatter/sqlite/sqls/ct_copilot_chats.sql +++ /dev/null @@ -1,19 +0,0 @@ --- 创建 copilot_chats 表 -PRAGMA foreign_keys = false; - -CREATE TABLE IF NOT EXISTS "copilot_chats" -( - chat_id integer not null - constraint copilot_chats_pk - primary key autoincrement, - wx_id integer - constraint copilot_chats_wx_users_wx_id_fk - references wx_users, - chat_topic TEXT, - chat_created_time TIMESTAMP default CURRENT_TIMESTAMP, - chat_talk_time TIMESTAMP, - chat_model TEXT, - is_chating BLOB default TRUE -); - -PRAGMA foreign_keys = true; diff --git a/wechatter/sqlite/sqls/ct_wx_users.sql b/wechatter/sqlite/sqls/ct_wx_users.sql deleted file mode 100644 index e79a4a6..0000000 --- a/wechatter/sqlite/sqls/ct_wx_users.sql +++ /dev/null @@ -1,16 +0,0 @@ --- 创建 wx_users 表 -PRAGMA foreign_keys = false; - -CREATE TABLE IF NOT EXISTS "wx_users" -( - wx_id TEXT not null - constraint wx_users_pk - primary key, - wx_name TEXT, - wx_alias TEXT, - wx_gender TEXT, - constraint gender_check - check ("wx_users".wx_gender in ('male', 'female', 'unknown')) -); - -PRAGMA foreign_keys = true; diff --git a/wechatter/utils/time.py b/wechatter/utils/time.py index 49961c6..160c8fe 100644 --- a/wechatter/utils/time.py +++ b/wechatter/utils/time.py @@ -1,5 +1,6 @@ # 获取时间工具类 import time +from datetime import datetime def get_current_hour() -> int: @@ -34,6 +35,14 @@ def get_current_timestamp() -> int: return int(time.time()) +def get_current_datetime_object() -> datetime: + """ + 获取当前时间对象 + :return: 返回当前时间对象 + """ + return datetime.now() + + def get_current_datetime() -> str: """ 获取当前时间 diff --git a/wechatter/webhook_handlers/github/create_handler.py b/wechatter/webhook_handlers/github/create_handler.py index 55035c8..8595344 100644 --- a/wechatter/webhook_handlers/github/create_handler.py +++ b/wechatter/webhook_handlers/github/create_handler.py @@ -9,7 +9,7 @@ def handle_create(data: dict): payload = GithubCreateWebhook(**data) if payload.ref_type == "branch": - logger.info(f"Branche {payload.ref} was created by {payload.sender.login}.") + logger.info(f"Branch {payload.ref} was created by {payload.sender.login}.") message = ( "==== GitHub Create 事件 ====\n" "🆕 有新的分支创建!\n" @@ -18,7 +18,7 @@ def handle_create(data: dict): f"🧑‍💻 创建者:{payload.sender.login}\n" f"🔗 查看详情:{payload.repository.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) elif payload.ref_type == "tag": logger.info(f"Tag {payload.ref} was created by {payload.sender.login}.") message = ( @@ -29,4 +29,4 @@ def handle_create(data: dict): f"🧑‍💻 创建者:{payload.sender.login}\n" f"🔗 查看详情:{payload.repository.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/delete_handler.py b/wechatter/webhook_handlers/github/delete_handler.py index 736c8f7..83ebc64 100644 --- a/wechatter/webhook_handlers/github/delete_handler.py +++ b/wechatter/webhook_handlers/github/delete_handler.py @@ -18,7 +18,7 @@ def handle_delete(data: dict): f"🧑‍💻 创建者:{payload.sender.login}\n" f"🔗 查看详情:{payload.repository.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) elif payload.ref_type == "tag": logger.info(f"Tag {payload.ref} was deleted by {payload.sender.login}") message = ( @@ -29,4 +29,4 @@ def handle_delete(data: dict): f"🧑‍💻 创建者:{payload.sender.login}\n" f"🔗 查看详情:{payload.repository.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/fork_handler.py b/wechatter/webhook_handlers/github/fork_handler.py index 77d4c3a..40e19de 100644 --- a/wechatter/webhook_handlers/github/fork_handler.py +++ b/wechatter/webhook_handlers/github/fork_handler.py @@ -15,4 +15,4 @@ def handle_fork(data: dict): "==== GitHub Fork 事件 ====\n" f"🍴 {payload.repository.full_name} 有新的 Fork!🆙\n" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/issue_comment_handler.py b/wechatter/webhook_handlers/github/issue_comment_handler.py index deb8285..602d394 100644 --- a/wechatter/webhook_handlers/github/issue_comment_handler.py +++ b/wechatter/webhook_handlers/github/issue_comment_handler.py @@ -25,4 +25,4 @@ def handle_issue_comment(data: dict): f"🧑‍💻 创建者:{payload.issue.user.login}\n" f"🔗 查看详情:{payload.issue.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/issue_handler.py b/wechatter/webhook_handlers/github/issue_handler.py index 94a3e7d..b4009b6 100644 --- a/wechatter/webhook_handlers/github/issue_handler.py +++ b/wechatter/webhook_handlers/github/issue_handler.py @@ -19,4 +19,4 @@ def handle_issue(data: dict): f"🧑‍💻 创建者:{payload.issue.user.login}\n" f"🔗 查看详情:{payload.issue.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/ping_handler.py b/wechatter/webhook_handlers/github/ping_handler.py index 3b5a476..3100ea1 100644 --- a/wechatter/webhook_handlers/github/ping_handler.py +++ b/wechatter/webhook_handlers/github/ping_handler.py @@ -16,4 +16,4 @@ def handle_ping(data: dict): f"📚 仓库:{payload.repository.full_name}\n" f"🧑‍💻 触发者:{payload.sender.login}\n" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/pr_handler.py b/wechatter/webhook_handlers/github/pr_handler.py index 7f4ae0c..deb971c 100644 --- a/wechatter/webhook_handlers/github/pr_handler.py +++ b/wechatter/webhook_handlers/github/pr_handler.py @@ -22,7 +22,7 @@ def handle_pr(data: dict): f"🔀 '{payload.pull_request.base.ref}' ⬅ '{payload.pull_request.head.ref}'\n" f"🔗 查看详情:{payload.pull_request.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) return message = ( "==== GitHub Pull Request 事件 ====\n" @@ -33,4 +33,4 @@ def handle_pr(data: dict): f"🧑‍💻 创建者:{payload.pull_request.user.login}\n" f"🔗 查看详情:{payload.pull_request.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/pr_review_handler.py b/wechatter/webhook_handlers/github/pr_review_handler.py index 7bdb684..b7998da 100644 --- a/wechatter/webhook_handlers/github/pr_review_handler.py +++ b/wechatter/webhook_handlers/github/pr_review_handler.py @@ -19,4 +19,4 @@ def handle_pr_review(data: dict): f"🧑‍💻 创建者:{payload.pull_request.user.login}\n" f"🔗 查看详情:{payload.pull_request.html_url}" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/push_handler.py b/wechatter/webhook_handlers/github/push_handler.py index 8928f53..188d405 100644 --- a/wechatter/webhook_handlers/github/push_handler.py +++ b/wechatter/webhook_handlers/github/push_handler.py @@ -26,4 +26,4 @@ def handle_push(data: dict): message += f"📃 提交信息:{payload.commits.pop().message}\n" message += f"🔗 查看详情:{branch_url}" - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) diff --git a/wechatter/webhook_handlers/github/star_handler.py b/wechatter/webhook_handlers/github/star_handler.py index 76c7cf6..8601c55 100644 --- a/wechatter/webhook_handlers/github/star_handler.py +++ b/wechatter/webhook_handlers/github/star_handler.py @@ -14,10 +14,10 @@ def handle_star(data: dict): "==== GitHub Star 事件 ====\n" f"⭐️ {payload.repository.full_name} 的 Star 数量 +1 🆙!\n" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message) else: message = ( "==== GitHub Star 事件 ====\n" f"⭐️ {payload.repository.full_name} 的 Star 数量 -1 🔽!\n" ) - sender.send_msg_to_github_webhook_receivers(message) + sender.mass_send_msg_to_github_webhook_receivers(message)