Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 新增图片推理模式来支持gpt4v #265

Merged
merged 13 commits into from
Nov 20, 2023
6 changes: 3 additions & 3 deletions code/handlers/card_common_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"start-feishubot/logger"

larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
)

Expand All @@ -20,11 +18,13 @@ func NewCardHandler(m MessageHandler) CardHandlerFunc {
handlers := []CardHandlerMeta{
NewClearCardHandler,
NewPicResolutionHandler,
NewVisionResolutionHandler,
NewPicTextMoreHandler,
NewPicModeChangeHandler,
NewRoleTagCardHandler,
NewRoleCardHandler,
NewAIModeCardHandler,
NewVisionModeChangeHandler,
}

return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
Expand All @@ -35,7 +35,7 @@ func NewCardHandler(m MessageHandler) CardHandlerFunc {
return nil, err
}
//pp.Println(cardMsg)
logger.Debug("cardMsg ", cardMsg)
//logger.Debug("cardMsg ", cardMsg)
for _, handler := range handlers {
h := handler(cardMsg, m)
i, err := h(ctx, cardAction)
Expand Down
74 changes: 74 additions & 0 deletions code/handlers/card_vision_action.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package handlers

import (
"context"
"fmt"
larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"start-feishubot/services"
)

// NewVisionResolutionHandler builds the card handler for VisionStyleKind
// actions. For that kind it applies the selected option through
// CommonProcessVisionStyle and returns no replacement card; for every other
// kind it returns ErrNextHandler.
func NewVisionResolutionHandler(cardMsg CardMsg,
	m MessageHandler) CardHandlerFunc {
	return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
		// Not our kind of card: hand over to the next handler in the list.
		if cardMsg.Kind != VisionStyleKind {
			return nil, ErrNextHandler
		}
		CommonProcessVisionStyle(cardMsg, cardAction, m.sessionCache)
		return nil, nil
	}
}
// NewVisionModeChangeHandler builds the card handler for VisionModeChangeKind
// actions. When CommonProcessVisionModeChange reports that it handled the
// action, its card and error are returned; otherwise nothing is returned.
// Any other card kind yields ErrNextHandler.
func NewVisionModeChangeHandler(cardMsg CardMsg,
	m MessageHandler) CardHandlerFunc {
	return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
		if cardMsg.Kind != VisionModeChangeKind {
			return nil, ErrNextHandler
		}
		if card, err, handled := CommonProcessVisionModeChange(cardMsg, m.sessionCache); handled {
			return card, err
		}
		return nil, nil
	}
}

// CommonProcessVisionStyle stores the option selected on the card as the
// session's vision detail level and replies to the originating message with a
// text confirmation.
func CommonProcessVisionStyle(msg CardMsg,
	cardAction *larkcard.CardAction,
	cache services.SessionServiceCacheInterface) {
	selected := cardAction.Action.Option
	fmt.Println(larkcore.Prettify(msg))

	detail := services.VisionDetail(selected)
	cache.SetVisionDetail(msg.SessionId, detail)

	// Send a plain-text confirmation as a reply to the card's message.
	notice := "图片解析度调整为:" + selected
	replyMsg(context.Background(), notice, &msg.MsgId)
}

// CommonProcessVisionModeChange reacts to the vision mode confirmation card.
//
// Value "1": the session is cleared, switched to services.ModeVision with
// services.VisionDetailLow, and a card announcing vision mode is returned.
// Value "0": a card stating that the topic context is kept is returned.
// Any other value: nothing is done and the final result is false.
//
// The boolean result tells the caller whether a card was produced. The error
// returned by newSendCard is discarded, so the error result is always nil.
func CommonProcessVisionModeChange(cardMsg CardMsg,
	session services.SessionServiceCacheInterface) (
	interface{}, error, bool) {
	switch cardMsg.Value {
	case "1":
		sessionId := cardMsg.SessionId
		session.Clear(sessionId)
		session.SetMode(sessionId, services.ModeVision)
		session.SetVisionDetail(sessionId, services.VisionDetailLow)

		card, _ := newSendCard(
			withHeader("🕵️️ 已进入图片推理模式", larkcard.TemplateBlue),
			withVisionDetailLevelBtn(&sessionId),
			withNote("提醒:回复图片,让LLM和你一起推理图片的内容。"))
		return card, nil, true
	case "0":
		card, _ := newSendCard(
			withHeader("️🎒 机器人提醒", larkcard.TemplateGreen),
			withMainMd("依旧保留此话题的上下文信息"),
			withNote("我们可以继续探讨这个话题,期待和您聊天。如果您有其他问题或者想要讨论的话题,请告诉我哦"),
		)
		return card, nil, true
	default:
		return nil, nil, false
	}
}
28 changes: 27 additions & 1 deletion code/handlers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ func msgFilter(msg string) string {
//replace @到下一个非空的字段 为 ''
regex := regexp.MustCompile(`@[^ ]*`)
return regex.ReplaceAllString(msg, "")

}

// Parse rich text json to text
Expand Down Expand Up @@ -47,6 +46,33 @@ func parsePostContent(content string) string {
return msgFilter(text)
}

// parsePostImageKeys extracts the image_key of every element tagged "img"
// from a rich-text (post) message body.
//
// content is the raw JSON message content. The expected shape is an object
// whose "content" field is a list of paragraphs, each paragraph being a list
// of element objects. The function is also fed the content of other message
// types, so anything that does not match this shape is skipped instead of
// triggering a type-assertion panic.
//
// It returns the keys in document order, or nil when the JSON is invalid or
// contains no usable "content" list.
func parsePostImageKeys(content string) []string {
	var contentMap map[string]interface{}
	if err := json.Unmarshal([]byte(content), &contentMap); err != nil {
		fmt.Println(err)
		return nil
	}

	// Missing, null, or non-list "content" means there are no images to report.
	paragraphs, ok := contentMap["content"].([]interface{})
	if !ok {
		return nil
	}

	var imageKeys []string
	for _, paragraph := range paragraphs {
		elements, ok := paragraph.([]interface{})
		if !ok {
			continue
		}
		for _, element := range elements {
			node, ok := element.(map[string]interface{})
			if !ok || node["tag"] != "img" {
				continue
			}
			// An img element without a string image_key cannot be downloaded.
			if key, ok := node["image_key"].(string); ok {
				imageKeys = append(imageKeys, key)
			}
		}
	}

	return imageKeys
}

func parseContent(content, msgType string) string {
//"{\"text\":\"@_user_1 hahaha\"}",
//only get text content hahaha
Expand Down
1 change: 1 addition & 0 deletions code/handlers/event_common_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type MsgInfo struct {
qParsed string
fileKey string
imageKey string
imageKeys []string // post 消息卡片中的图片组
sessionId *string
mention []*larkim.MentionEvent
}
Expand Down
17 changes: 17 additions & 0 deletions code/handlers/event_msg_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ func setDefaultPrompt(msg []openai.Messages) []openai.Messages {
return msg
}

//func setDefaultVisionPrompt(msg []openai.VisionMessages) []openai.VisionMessages {
// if !hasSystemRole(msg) {
// msg = append(msg, openai.VisionMessages{
// Role: "system", Content: []openai.ContentType{
// {Type: "text", Text: "You are ChatGPT4V, " +
// "You are ChatGPT4V, " +
// "a large language and picture model trained by" +
// " OpenAI. " +
// "Answer in user's language as concisely as" +
// " possible. Knowledge cutoff: 20230601 " +
// "Current date" + time.Now().Format("20060102"),
// }},
// })
// }
// return msg
//}

// MessageAction is the message-handling step registered in the action chain
// of the message-received handler.
type MessageAction struct { /* messages */
}

Expand Down
160 changes: 160 additions & 0 deletions code/handlers/event_vision_action.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package handlers

import (
"context"
"fmt"
"os"
"start-feishubot/initialization"
"start-feishubot/services"
"start-feishubot/services/openai"
"start-feishubot/utils"

larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
)

// VisionAction is the image-reasoning (vision) step registered in the action
// chain of the message-received handler.
type VisionAction struct { /* image reasoning */
}

// Execute runs the vision step for one incoming message.
//
// It returns true without doing anything when AzureModeCheck fails or when
// the message is not vision-related. A vision command resets the session into
// vision mode and sends the instruction card. An "image" message is analysed
// when the session is in vision mode, otherwise a mode-check card is sent. A
// "post" message is analysed only when the session is in vision mode.
//
// NOTE(review): the return value presumably tells the chain whether to keep
// running later actions (true) or stop (false) — confirm against chain().
func (va *VisionAction) Execute(a *ActionInfo) bool {
	if !AzureModeCheck(a) {
		return true
	}

	if isVisionCommand(a) {
		initializeVisionMode(a)
		sendVisionInstructionCard(*a.ctx, a.info.sessionId, a.info.msgId)
		return false
	}

	inVisionMode := a.handler.sessionCache.GetMode(*a.info.sessionId) == services.ModeVision

	switch a.info.msgType {
	case "image":
		if inVisionMode {
			return va.handleVisionImage(a)
		}
		// An image outside vision mode: ask the user whether to switch modes.
		sendVisionModeCheckCard(*a.ctx, a.info.sessionId, a.info.msgId)
		return false
	case "post":
		if inVisionMode {
			return va.handleVisionPost(a)
		}
	}

	return true
}

// isVisionCommand reports whether the parsed message text matches one of the
// vision-mode commands ("/vision" or "图片推理") according to
// utils.EitherTrimEqual.
func isVisionCommand(a *ActionInfo) bool {
	_, matched := utils.EitherTrimEqual(a.info.qParsed, "/vision", "图片推理")
	return matched
}

// initializeVisionMode clears the current session, switches it to
// services.ModeVision and sets its detail level to services.VisionDetailHigh.
func initializeVisionMode(a *ActionInfo) {
	cache := a.handler.sessionCache
	sid := *a.info.sessionId

	cache.Clear(sid)
	cache.SetMode(sid, services.ModeVision)
	cache.SetVisionDetail(sid, services.VisionDetailHigh)
}

// handleVisionImage processes a single-image message: it reads the session's
// detail level, downloads and base64-encodes the image, and hands it to
// processImageAndReply. A download failure is reported to the user and the
// method returns false.
func (va *VisionAction) handleVisionImage(a *ActionInfo) bool {
	detail := a.handler.sessionCache.GetVisionDetail(*a.info.sessionId)

	encoded, err := downloadAndEncodeImage(a.info.imageKey, a.info.msgId)
	if err == nil {
		return va.processImageAndReply(a, encoded, detail)
	}

	replyWithErrorMsg(*a.ctx, err, a.info.msgId)
	return false
}

// handleVisionPost processes a rich-text (post) message: every non-empty
// image key is downloaded and base64-encoded, then all images are passed to
// processMultipleImagesAndReply. The first download failure is reported to the
// user and aborts processing. If the post has no images, the user is asked to
// send one. In both of those cases the method returns false.
func (va *VisionAction) handleVisionPost(a *ActionInfo) bool {
	detail := a.handler.sessionCache.GetVisionDetail(*a.info.sessionId)

	var encodedImages []string
	for _, key := range a.info.imageKeys {
		if key != "" {
			encoded, err := downloadAndEncodeImage(key, a.info.msgId)
			if err != nil {
				replyWithErrorMsg(*a.ctx, err, a.info.msgId)
				return false
			}
			encodedImages = append(encodedImages, encoded)
		}
	}

	if len(encodedImages) > 0 {
		return va.processMultipleImagesAndReply(a, encodedImages, detail)
	}

	replyMsg(*a.ctx, "🤖️:请发送一张图片", a.info.msgId)
	return false
}

// downloadAndEncodeImage fetches the image resource identified by imageKey
// from the message msgId, stores it in a temporary "<imageKey>.png" file in
// the working directory, and returns the file's content as produced by
// openai.GetBase64FromImage. The temporary file is removed before returning.
//
// An error is returned when the request fails, when the API answers with a
// non-success code, when the file cannot be written, or when encoding fails.
func downloadAndEncodeImage(imageKey string, msgId *string) (string, error) {
	f := fmt.Sprintf("%s.png", imageKey)
	// Removing a file that was never created is harmless, so the cleanup can
	// be registered before the download.
	defer os.Remove(f)

	req := larkim.NewGetMessageResourceReqBuilder().MessageId(*msgId).FileKey(imageKey).Type("image").Build()
	resp, err := initialization.GetLarkClient().Im.MessageResource.Get(context.Background(), req)
	if err != nil {
		return "", err
	}
	// A transport-level success can still carry an API error; in that case
	// there is no file payload to write.
	if !resp.Success() {
		return "", fmt.Errorf("download image %q: code %d: %s", imageKey, resp.Code, resp.Msg)
	}

	if err := resp.WriteFile(f); err != nil {
		return "", fmt.Errorf("save image %q: %w", imageKey, err)
	}
	return openai.GetBase64FromImage(f)
}

// replyWithErrorMsg replies to the message msgId with a fixed
// "image download failed" notice followed by the text of err.
func replyWithErrorMsg(ctx context.Context, err error, msgId *string) {
	notice := fmt.Sprintf("🤖️:图片下载失败,请稍后再试~\n 错误信息: %v", err)
	replyMsg(ctx, notice, msgId)
}

// processImageAndReply asks the vision model to explain a single
// base64-encoded image at the given detail level and sends the answer back as
// a vision topic card. It always returns false.
//
// A failed model request is reported with its own message rather than through
// replyWithErrorMsg, whose text says the image download failed.
func (va *VisionAction) processImageAndReply(a *ActionInfo, base64 string, detail string) bool {
	// The single-image path has no user text, so a fixed prompt is used.
	msg := createVisionMessages("解释这个图片", base64, detail)
	completions, err := a.handler.gpt.GetVisionInfo(msg)
	if err != nil {
		replyMsg(*a.ctx, fmt.Sprintf("🤖️:图片推理失败,请稍后再试~\n 错误信息: %v", err), a.info.msgId)
		return false
	}
	sendVisionTopicCard(*a.ctx, a.info.sessionId, a.info.msgId, completions.Content)
	return false
}

// processMultipleImagesAndReply sends the user's parsed text together with
// all base64-encoded images to the vision model at the given detail level and
// sends the answer back as a vision topic card. It always returns false.
//
// A failed model request is reported with its own message rather than through
// replyWithErrorMsg, whose text says the image download failed.
func (va *VisionAction) processMultipleImagesAndReply(a *ActionInfo, base64s []string, detail string) bool {
	msg := createMultipleVisionMessages(a.info.qParsed, base64s, detail)
	completions, err := a.handler.gpt.GetVisionInfo(msg)
	if err != nil {
		replyMsg(*a.ctx, fmt.Sprintf("🤖️:图片推理失败,请稍后再试~\n 错误信息: %v", err), a.info.msgId)
		return false
	}
	sendVisionTopicCard(*a.ctx, a.info.sessionId, a.info.msgId, completions.Content)
	return false
}

// createVisionMessages builds a one-element conversation: a single "user"
// message holding the text query followed by one image, embedded as a
// "data:image/jpeg;base64," URL with the given detail level.
func createVisionMessages(query, base64Image, detail string) []openai.VisionMessages {
	textPart := openai.ContentType{Type: "text", Text: query}
	imagePart := openai.ContentType{
		Type: "image_url",
		ImageURL: &openai.ImageURL{
			URL:    "data:image/jpeg;base64," + base64Image,
			Detail: detail,
		},
	}

	userMsg := openai.VisionMessages{
		Role:    "user",
		Content: []openai.ContentType{textPart, imagePart},
	}
	return []openai.VisionMessages{userMsg}
}

// createMultipleVisionMessages builds a one-element conversation: a single
// "user" message holding the text query followed by one image part per entry
// of base64Images, each embedded as a "data:image/jpeg;base64," URL with the
// given detail level.
func createMultipleVisionMessages(query string, base64Images []string, detail string) []openai.VisionMessages {
	parts := make([]openai.ContentType, 0, len(base64Images)+1)
	parts = append(parts, openai.ContentType{Type: "text", Text: query})

	for i := range base64Images {
		image := &openai.ImageURL{
			URL:    "data:image/jpeg;base64," + base64Images[i],
			Detail: detail,
		}
		parts = append(parts, openai.ContentType{Type: "image_url", ImageURL: image})
	}

	return []openai.VisionMessages{{Role: "user", Content: parts}}
}
5 changes: 3 additions & 2 deletions code/handlers/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
qParsed: strings.Trim(parseContent(*content, msgType), " "),
fileKey: parseFileKey(*content),
imageKey: parseImageKey(*content),
imageKeys: parsePostImageKeys(*content),
sessionId: sessionId,
mention: mention,
}
Expand All @@ -94,17 +95,17 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
&ProcessedUniqueAction{}, //避免重复处理
&ProcessMentionAction{}, //判断机器人是否应该被调用
&AudioAction{}, //语音处理
&EmptyAction{}, //空消息处理
&ClearAction{}, //清除消息处理
&VisionAction{}, //图片推理处理
&PicAction{}, //图片处理
&AIModeAction{}, //模式切换处理
&RoleListAction{}, //角色列表处理
&HelpAction{}, //帮助处理
&BalanceAction{}, //余额处理
&RolePlayAction{}, //角色扮演处理
&MessageAction{}, //消息处理
&EmptyAction{}, //空消息处理
&StreamMessageAction{}, //流式消息处理

}
chain(data, actions...)
return nil
Expand Down
Loading
Loading