diff --git a/ai.go b/ai_gemini.go
similarity index 97%
rename from ai.go
rename to ai_gemini.go
index 396db6d..1132bc8 100644
--- a/ai.go
+++ b/ai_gemini.go
@@ -37,7 +37,7 @@ type Data struct {
 	Contents []Content `json:"contents"`
 }
 
-func askAI(userID int64) string {
+func askGemini(userID int64) string {
 	var data Data
 	chats := getChats(userID)
 	if len(chats) > 0 {
diff --git a/ai_openai.go b/ai_openai.go
new file mode 100644
index 0000000..f66a84f
--- /dev/null
+++ b/ai_openai.go
@@ -0,0 +1,103 @@
+package main
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"net/http"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+)
+
+// ChatMessage is a single message in an OpenAI chat-completion conversation.
+type ChatMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ChatRequest is the JSON body sent to the chat-completions endpoint.
+type ChatRequest struct {
+	Model    string        `json:"model"`
+	Messages []ChatMessage `json:"messages"`
+}
+
+// ChatResponse holds the subset of the chat-completions response we read.
+type ChatResponse struct {
+	Choices []struct {
+		Message struct {
+			Content string `json:"content"`
+		} `json:"message"`
+	} `json:"choices"`
+}
+
+// mapRole translates the locally stored chat role names to the OpenAI
+// role names; unknown roles default to "user".
+func mapRole(originalRole string) string {
+	switch originalRole {
+	case "USER":
+		return "user"
+	case "MODEL":
+		return "assistant"
+	default:
+		return "user"
+	}
+}
+
+const openAI = "https://gptmos.com/v1/chat/completions"
+
+// askOpenAI sends the stored chat history of userID to the OpenAI-compatible
+// endpoint and returns the assistant reply. On failure the error text is
+// returned so the caller can surface it to the user.
+func askOpenAI(userID int64) string {
+	var chatReq ChatRequest
+	for _, chat := range getChats(userID) {
+		chatReq.Messages = append(chatReq.Messages, ChatMessage{
+			Role:    mapRole(chat.Role),
+			Content: chat.Text,
+		})
+	}
+
+	chatReq.Model = "gpt-4-0125-preview"
+
+	jsonData, err := json.Marshal(chatReq)
+	if err != nil {
+		log.Errorf("Failed to marshal request: %v", err)
+		return err.Error()
+	}
+	log.Infoln(string(jsonData))
+
+	req, err := http.NewRequest("POST", openAI, bytes.NewBuffer(jsonData))
+	if err != nil {
+		log.Errorf("Failed to build request: %v", err)
+		return err.Error()
+	}
+	req.Header.Set("Authorization", "Bearer "+OpenAIKey)
+	// The API rejects bodies without an explicit JSON content type.
+	req.Header.Set("Content-Type", "application/json")
+	// Bounded timeout so a stalled upstream cannot hang the handler forever.
+	client := &http.Client{Timeout: 60 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Errorf("Failed to make request: %v", err)
+		return err.Error()
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		text := string(body)
+		log.Errorf("Request failed, %s", text)
+		return text
+	}
+
+	var chatResp ChatResponse
+	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+		log.Errorf("Failed to decode response: %v", err)
+		return err.Error()
+	}
+
+	if len(chatResp.Choices) > 0 && len(chatResp.Choices[0].Message.Content) > 0 {
+		return chatResp.Choices[0].Message.Content
+	}
+
+	return "no response found"
+}
diff --git a/config.go b/config.go
index f48ddc8..6f2f037 100644
--- a/config.go
+++ b/config.go
@@ -53,7 +53,7 @@ const (
 
 var apiKey = os.Getenv("GEMINI_API_KEY")
 var geminiURL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=" + apiKey
-
+var OpenAIKey = os.Getenv("OPENAI_API_KEY")
 var (
 	selector = &tb.ReplyMarkup{}
 	btnPrev  = selector.Data("Ask AI", "ai-init", "1")
diff --git a/handler.go b/handler.go
index d790f53..71b4359 100644
--- a/handler.go
+++ b/handler.go
@@ -56,7 +56,7 @@ func mainEntrance(c tb.Context) error {
 		} else {
 			addChat(c.Sender().ID, userRole, c.Message().Text)
 		}
-		aiResponse := askAI(c.Sender().ID)
+		aiResponse := askGemini(c.Sender().ID)
 		addChat(c.Sender().ID, modelRole, aiResponse)
 		return c.Send(aiResponse, tb.NoPreview)
 	} else {
@@ -84,7 +84,7 @@ func testEntrance(c tb.Context) error {
 		} else {
 			addChat(c.Sender().ID, userRole, c.Message().Text)
 		}
-		aiResponse := askAI(c.Sender().ID)
+		aiResponse := askGemini(c.Sender().ID)
 		addChat(c.Sender().ID, modelRole, aiResponse)
 		return c.Send(aiResponse, tb.NoPreview)
 	} else {