Skip to content

Commit

Permalink
gpt: drop resty and slash cmd, trim oputput to discord limit
Browse files Browse the repository at this point in the history
  • Loading branch information
inciner8r committed Mar 2, 2024
1 parent 49b72f0 commit ff59e0c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 138 deletions.
1 change: 0 additions & 1 deletion app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,4 @@ func Init() {
<-sc
fmt.Println("shutting down")
// Cleanly close down the Discord session.
sess.Close()
}
42 changes: 0 additions & 42 deletions app/commands/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,6 @@ func RegisterCommands() []*discordgo.ApplicationCommand {
},
},
},
{
Name: "gpt",
Description: "Generate text with gpt-4",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
Name: "prompt",
Description: "prompt to generate text",
Required: true,
},
},
},
{
Name: "upscale",
Description: "Upscale one of the generated image",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
Name: "choice",
Description: "choice of image to upscale",
Required: true,
Choices: []*discordgo.ApplicationCommandOptionChoice{
{
Name: "1",
Value: "1",
},
{
Name: "2",
Value: "2",
},
{
Name: "3",
Value: "3",
},
{
Name: "4",
Value: "4",
},
},
},
},
},
}
return commands
}
104 changes: 31 additions & 73 deletions app/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func AddHandlers(sess *discordgo.Session) {
registeredCommands[i] = cmd
fmt.Println("command registered: ", cmd.Name)
}

defer sess.Close()
commandHandlers := map[string]func(s *discordgo.Session, i *discordgo.InteractionCreate){
"generate": func(s *discordgo.Session, i *discordgo.InteractionCreate) {
options := i.ApplicationCommandData().Options
Expand Down Expand Up @@ -52,78 +54,6 @@ func AddHandlers(sess *discordgo.Session) {

}
},
"gpt": func(s *discordgo.Session, i *discordgo.InteractionCreate) {
options := i.ApplicationCommandData().Options

optionMap := make(map[string]*discordgo.ApplicationCommandInteractionDataOption, len(options))
for _, opt := range options {
optionMap[opt.Name] = opt
}

margs := make([]string, 0, len(options))
msgformat := "Take a look at your response:\n"

if option, ok := optionMap["prompt"]; ok {
margs = append(margs, option.StringValue())
prompt := strings.Join(margs[:], " ")
res, err := chatgpt.GetChatGPTResponse(prompt)
msg := msgformat + res
fmt.Println("reply: ", msg)
if err != nil {
fmt.Println("error in generating response:", err.Error())
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "error generating response",
},
})
}
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: msg,
},
})
}
},
"upscale": func(s *discordgo.Session, i *discordgo.InteractionCreate) {
options := i.ApplicationCommandData().Options
choice := options[0].StringValue()
choiceInt, err := strconv.Atoi(choice)
if err != nil {
fmt.Println("error in upscaling image:", err.Error())
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "error upscaling image",
},
})
}
repliedMessageID := i.Message.MessageReference.MessageID
imageURL, _, err := getImageFromMessageID(s, os.Getenv("CHANNEL_ID"), repliedMessageID)
if err != nil {
fmt.Println("error in upscaling image:", err.Error())
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "error upscaling image",
},
})
}

sess_id := s.State.SessionID
nonce := fmt.Sprint(rand.Int())
err = Upscale(int(choiceInt), repliedMessageID, imageURL, sess_id, nonce)
if err != nil {
fmt.Println("error in upscaling image:", err.Error())
s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: "error upscaling image",
},
})
}
},
}
const prefix = "!airbot"
sess.AddHandler(func(s *discordgo.Session, i *discordgo.InteractionCreate) {
Expand Down Expand Up @@ -246,6 +176,35 @@ func AddHandlers(sess *discordgo.Session) {
UpscaleCreative(number, repliedMessageID, imageID, sess_id, nonce)
}
}
if args[1] == "gpt" {
parts := strings.SplitN(m.Content, " ", 3)
if len(parts) < 3 {
s.ChannelMessageSend(m.ChannelID, "Invalid format. Usage: !airbot gpt <prompt>")
return
}
prompt := parts[2]

res, err := chatgpt.GetChatGPTResponse(prompt)
if err != nil {
fmt.Println("Error generating response:", err.Error())
s.ChannelMessageSend(m.ChannelID, "Error generating response.")
return
}
fmt.Println("res", res)
// Truncate the response if it exceeds Discord's maximum message length
if len(res) > 2000 {
res = res[:2000]
}
reply := &discordgo.MessageReference{
MessageID: m.ID,
}
_, err = s.ChannelMessageSendReply(m.ChannelID, res, reply)
if err != nil {
fmt.Println("Error sending message reply:", err.Error())
return
}
}

if args[1] == "help" {
const helpMessage = "Available commands:\n" +
"1. /generate <prompt>: Generates text based on the provided prompt.\n" +
Expand All @@ -262,7 +221,6 @@ func AddHandlers(sess *discordgo.Session) {
MessageID: m.ID,
}
s.ChannelMessageSendReply(m.ChannelID, helpMessage, reply)
return
}
})
}
57 changes: 35 additions & 22 deletions utils/chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
package chatgpt

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"

"github.com/go-resty/resty/v2"
)

var (
chatGPTAPIKey = os.Getenv("OPENAI_KEY")
chatGPTURL = "https://api.openai.com/v1/chat/completions"
)

func GetChatGPTResponse(prompt string) (string, error) {
client := resty.New()
var result response
request := request{
Model: "gpt-4",
var (
chatGPTAPIKey = os.Getenv("OPENAI_KEY")
chatGPTURL = "https://api.openai.com/v1/chat/completions"
)
requestData := request{
Model: "gpt-4-turbo-preview",
Messages: []struct {
Role string `json:"role"`
Content string `json:"content"`
Expand All @@ -29,28 +26,44 @@ func GetChatGPTResponse(prompt string) (string, error) {
},
}

resp, err := client.R().
SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+chatGPTAPIKey).
SetBody(request).SetResult(&result).
Post(chatGPTURL)
requestBody, err := json.Marshal(requestData)
if err != nil {
return "", err
}

req, err := http.NewRequest("POST", chatGPTURL, bytes.NewBuffer(requestBody))
if err != nil {
return "", err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+chatGPTAPIKey)

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", errors.New("failed to fetch response. status code: " + resp.Status)
}

fmt.Println("POST Response:", resp.Status())
fmt.Println(len(result.Choices) >= 1)
var result response
err = json.NewDecoder(resp.Body).Decode(&result)
if err != nil {
return "", err
}
if len(result.Choices) < 1 {
return "", errors.New("failed to fetch response")
}
return result.Choices[0].Message.Content, nil
}

type request struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []struct {
Model string `json:"model"`
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
Expand Down

0 comments on commit ff59e0c

Please sign in to comment.