diff --git a/.gitignore b/.gitignore index 5cb49d32..5d7ebc65 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ docker.md # Mac OS .DS_Store **/.DS_Store +*.pem diff --git a/code/.gitignore b/code/.gitignore new file mode 100644 index 00000000..d742ec84 --- /dev/null +++ b/code/.gitignore @@ -0,0 +1,2 @@ +/apikey_usage.json +*.pem diff --git a/code/config.example.yaml b/code/config.example.yaml index 54d543a5..14744be7 100644 --- a/code/config.example.yaml +++ b/code/config.example.yaml @@ -1,10 +1,16 @@ -# 飞书 +# 飞书 # 不要随意修改example中配置的顺序和空格,否则docker的sed脚本会执行出错 APP_ID: cli_axxx APP_SECRET: xxx APP_ENCRYPT_KEY: xxx APP_VERIFICATION_TOKEN: xxx -# 请确保和飞书应用管理平台中的设置一致 +# 请确保和飞书应用管理平台中的设置一致 BOT_NAME: chatGpt -# openAI -OPENAI_KEY: sk-xxx +# openAI key 支持负载均衡 可以填写多个key 用逗号分隔 +OPENAI_KEY: sk-xxx,sk-xxx,sk-xxx +# 服务器配置 +HTTP_PORT: 9000 +HTTPS_PORT: 9001 +USE_HTTPS: false +CERT_FILE: cert.pem +KEY_FILE: key.pem diff --git a/code/handlers/common.go b/code/handlers/common.go index a2583862..bc510496 100644 --- a/code/handlers/common.go +++ b/code/handlers/common.go @@ -8,7 +8,7 @@ import ( "strings" ) -//func sendCard +// func sendCard func msgFilter(msg string) string { //replace @到下一个非空的字段 为 '' regex := regexp.MustCompile(`@[^ ]*`) @@ -50,14 +50,14 @@ func processQuote(msg string) string { return strings.Replace(msg, "\\\"", "\"", -1) } -//将字符中 \u003c 替换为 < 等等 +// 将字符中 \u003c 替换为 < 等等 func processUnicode(msg string) string { regex := regexp.MustCompile(`\\u[0-9a-fA-F]{4}`) return regex.ReplaceAllStringFunc(msg, func(s string) string { r, _ := regexp.Compile(`\\u`) s = r.ReplaceAllString(s, "") i, _ := strconv.ParseInt(s, 16, 32) - return string(i) + return strconv.Itoa(int(i)) }) } diff --git a/code/handlers/handler.go b/code/handlers/handler.go index 2f46a5ae..95b13d23 100644 --- a/code/handlers/handler.go +++ b/code/handlers/handler.go @@ -13,7 +13,7 @@ import ( larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" ) -//责任链 +// 责任链 func chain(data *ActionInfo, actions ...Action) bool { for _, v := range actions { if !v.Execute(data) { @@ -26,7 +26,7 @@ func chain(data *ActionInfo, actions ...Action) bool { type MessageHandler struct { sessionCache services.SessionServiceCacheInterface msgCache services.MsgCacheInterface - gpt services.ChatGPT + gpt *services.ChatGPT config initialization.Config } @@ -153,7 +153,7 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2 var _ MessageHandlerInterface = (*MessageHandler)(nil) -func NewMessageHandler(gpt services.ChatGPT, +func NewMessageHandler(gpt *services.ChatGPT, config initialization.Config) MessageHandlerInterface { return &MessageHandler{ sessionCache: services.GetSessionCache(), diff --git a/code/handlers/init.go b/code/handlers/init.go index 76ce0e0c..8b729949 100644 --- a/code/handlers/init.go +++ b/code/handlers/init.go @@ -25,7 +25,7 @@ const ( // handlers 所有消息类型类型的处理器 var handlers MessageHandlerInterface -func InitHandlers(gpt services.ChatGPT, config initialization.Config) { +func InitHandlers(gpt *services.ChatGPT, config initialization.Config) { handlers = NewMessageHandler(gpt, config) } @@ -50,7 +50,7 @@ func CardHandler() func(ctx context.Context, func judgeCardType(cardAction *larkcard.CardAction) HandlerType { actionValue := cardAction.Action.Value chatType := actionValue["chatType"] - fmt.Printf("chatType: %v", chatType) + //fmt.Printf("chatType: %v", chatType) if chatType == "group" { return GroupHandler } diff --git a/code/initialization/config.go b/code/initialization/config.go index 3ef97051..5a8ac829 100644 --- a/code/initialization/config.go +++ b/code/initialization/config.go @@ -2,6 +2,9 @@ package initialization import ( "fmt" + "os" + "strconv" + "strings" "github.com/spf13/viper" ) @@ -12,29 +15,103 @@ type Config struct { FeishuAppEncryptKey string FeishuAppVerificationToken string FeishuBotName string - OpenaiApiKey string + OpenaiApiKeys []string + HttpPort int + HttpsPort int + UseHttps bool + CertFile string + KeyFile string } func LoadConfig(cfg string) *Config { viper.SetConfigFile(cfg) viper.ReadInConfig() viper.AutomaticEnv() + //content, err := ioutil.ReadFile("config.yaml") + //if err != nil { + // fmt.Println("Error reading file:", err) + //} + //fmt.Println(string(content)) - return &Config{ - FeishuAppId: getViperStringValue("APP_ID"), - FeishuAppSecret: getViperStringValue("APP_SECRET"), - FeishuAppEncryptKey: getViperStringValue("APP_ENCRYPT_KEY"), - FeishuAppVerificationToken: getViperStringValue("APP_VERIFICATION_TOKEN"), - FeishuBotName: getViperStringValue("BOT_NAME"), - OpenaiApiKey: getViperStringValue("OPENAI_KEY"), + config := &Config{ + FeishuAppId: getViperStringValue("APP_ID", ""), + FeishuAppSecret: getViperStringValue("APP_SECRET", ""), + FeishuAppEncryptKey: getViperStringValue("APP_ENCRYPT_KEY", ""), + FeishuAppVerificationToken: getViperStringValue("APP_VERIFICATION_TOKEN", ""), + FeishuBotName: getViperStringValue("BOT_NAME", ""), + OpenaiApiKeys: getViperStringArray("OPENAI_KEY", nil), + HttpPort: getViperIntValue("HTTP_PORT", 9000), + HttpsPort: getViperIntValue("HTTPS_PORT", 9001), + UseHttps: getViperBoolValue("USE_HTTPS", false), + CertFile: getViperStringValue("CERT_FILE", "cert.pem"), + KeyFile: getViperStringValue("KEY_FILE", "key.pem"), } + return config } -func getViperStringValue(key string) string { +func getViperStringValue(key string, defaultValue string) string { value := viper.GetString(key) if value == "" { - panic(fmt.Errorf("%s MUST be provided in environment or config.yaml file", key)) + return defaultValue } return value } + +//OPENAI_KEY: sk-xxx,sk-xxx,sk-xxx +//result:[sk-xxx sk-xxx sk-xxx] +func getViperStringArray(key string, defaultValue []string) []string { + value := viper.GetString(key) + if value == "" { + return defaultValue + } + return strings.Split(value, ",") +} + +func getViperIntValue(key string, defaultValue int) int { + value := viper.GetString(key) + if value == "" { + return defaultValue + } + intValue, err := strconv.Atoi(value) + if err != nil { + fmt.Printf("Invalid value for %s, using default value %d\n", key, defaultValue) + return defaultValue + } + return intValue +} + +func getViperBoolValue(key string, defaultValue bool) bool { + value := viper.GetString(key) + if value == "" { + return defaultValue + } + boolValue, err := strconv.ParseBool(value) + if err != nil { + fmt.Printf("Invalid value for %s, using default value %v\n", key, defaultValue) + return defaultValue + } + return boolValue +} + +func (config *Config) GetCertFile() string { + if config.CertFile == "" { + return "cert.pem" + } + if _, err := os.Stat(config.CertFile); err != nil { + fmt.Printf("Certificate file %s does not exist, using default file cert.pem\n", config.CertFile) + return "cert.pem" + } + return config.CertFile +} + +func (config *Config) GetKeyFile() string { + if config.KeyFile == "" { + return "key.pem" + } + if _, err := os.Stat(config.KeyFile); err != nil { + fmt.Printf("Key file %s does not exist, using default file key.pem\n", config.KeyFile) + return "key.pem" + } + return config.KeyFile +} diff --git a/code/initialization/gin.go b/code/initialization/gin.go new file mode 100644 index 00000000..af645350 --- /dev/null +++ b/code/initialization/gin.go @@ -0,0 +1,58 @@ +package initialization + +import ( + "crypto/tls" + "fmt" + "github.com/gin-gonic/gin" + "log" + "net/http" + "time" +) + +func loadCertificate(config Config) (cert tls.Certificate, err error) { + cert, err = tls.LoadX509KeyPair(config.CertFile, config.KeyFile) + if err != nil { + return cert, fmt.Errorf("failed to load certificate: %v", err) + } + // check certificate expiry + certExpiry := cert.Leaf.NotAfter + if certExpiry.Before(time.Now()) { + return cert, fmt.Errorf("certificate expired on %v", certExpiry) + } + return cert, nil +} +func startHTTPServer(config Config, r *gin.Engine) (err error) { + log.Printf("http server started: http://localhost:%d/webhook/event\n", config.HttpPort) + err = r.Run(fmt.Sprintf(":%d", config.HttpPort)) + if err != nil { + return fmt.Errorf("failed to start http server: %v", err) + } + return nil +} +func startHTTPSServer(config Config, r *gin.Engine) (err error) { + cert, err := loadCertificate(config) + if err != nil { + return fmt.Errorf("failed to load certificate: %v", err) + } + server := &http.Server{ + Addr: fmt.Sprintf(":%d", config.HttpsPort), + Handler: r, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + } + fmt.Printf("https server started: https://localhost:%d/webhook/event\n", config.HttpsPort) + err = server.ListenAndServeTLS("", "") + if err != nil { + return fmt.Errorf("failed to start https server: %v", err) + } + return nil +} +func StartServer(config Config, r *gin.Engine) (err error) { + if config.UseHttps { + err = startHTTPSServer(config, r) + } else { + err = startHTTPServer(config, r) + } + return err +} diff --git a/code/main.go b/code/main.go index 6b3eaaaf..f448fc10 100644 --- a/code/main.go +++ b/code/main.go @@ -2,8 +2,8 @@ package main import ( "context" - "fmt" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + "log" "start-feishubot/handlers" "start-feishubot/initialization" "start-feishubot/services" @@ -26,9 +26,8 @@ func main() { pflag.Parse() config := initialization.LoadConfig(*cfg) initialization.LoadLarkClient(*config) - - gpt := services.NewChatGPT(config.OpenaiApiKey) - handlers.InitHandlers(*gpt, *config) + gpt := services.NewChatGPT(config.OpenaiApiKeys) + handlers.InitHandlers(gpt, *config) eventHandler := dispatcher.NewEventDispatcher( config.FeishuAppVerificationToken, config.FeishuAppEncryptKey). @@ -53,8 +52,9 @@ func main() { sdkginext.NewCardActionHandlerFunc( cardHandler)) - fmt.Println("http server started", - "http://localhost:9000/webhook/event") - r.Run(":9000") + err := initialization.StartServer(*config, r) + if err != nil { + log.Fatalf("failed to start server: %v", err) + } } diff --git a/code/services/gpt3.go b/code/services/gpt3.go index 18e006bf..549426ae 100644 --- a/code/services/gpt3.go +++ b/code/services/gpt3.go @@ -3,9 +3,11 @@ package services import ( "bytes" "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" + "start-feishubot/services/loadbalancer" "strings" "time" ) @@ -48,7 +50,8 @@ type ChatGPTRequestBody struct { PresencePenalty int `json:"presence_penalty"` } type ChatGPT struct { - ApiKey string + Lb *loadbalancer.LoadBalancer + ApiKey []string } type ImageGenerationRequestBody struct { @@ -65,7 +68,13 @@ type ImageGenerationResponseBody struct { } `json:"data"` } -func (gpt ChatGPT) sendRequest(url, method string, requestBody interface{}, responseBody interface{}) error { +func (gpt ChatGPT) sendRequest(url, method string, + requestBody interface{}, responseBody interface{}) error { + api := gpt.Lb.GetAPI() + if api == nil { + return errors.New("no available API") + } + requestData, err := json.Marshal(requestBody) if err != nil { return err @@ -77,16 +86,20 @@ func (gpt ChatGPT) sendRequest(url, method string, requestBody interface{}, resp } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+gpt.ApiKey) + req.Header.Set("Authorization", "Bearer "+api.Key) client := &http.Client{Timeout: 110 * time.Second} response, err := client.Do(req) if err != nil { + gpt.Lb.SetAvailability(api.Key, false) return err } defer response.Body.Close() + if response.StatusCode/2 != 100 { + gpt.Lb.SetAvailability(api.Key, false) return fmt.Errorf("%s api %s", strings.ToUpper(method), response.Status) } + body, err := ioutil.ReadAll(response.Body) if err != nil { return err @@ -96,6 +109,8 @@ func (gpt ChatGPT) sendRequest(url, method string, requestBody interface{}, resp if err != nil { return err } + + gpt.Lb.SetAvailability(api.Key, true) return nil } @@ -109,9 +124,9 @@ func (gpt ChatGPT) Completions(msg []Messages) (resp Messages, err error) { FrequencyPenalty: 0, PresencePenalty: 0, } - gptResponseBody := &ChatGPTResponseBody{} - err = gpt.sendRequest(BASEURL+"chat/completions", "POST", requestBody, gptResponseBody) + err = gpt.sendRequest(BASEURL+"chat/completions", "POST", + requestBody, gptResponseBody) if err == nil { resp = gptResponseBody.Choices[0].Message @@ -149,8 +164,10 @@ func (gpt ChatGPT) GenerateOneImage(prompt string, size string) (string, error) return b64s[0], nil } -func NewChatGPT(apiKey string) *ChatGPT { +func NewChatGPT(apiKeys []string) *ChatGPT { + lb := loadbalancer.NewLoadBalancer(apiKeys) return &ChatGPT{ - ApiKey: apiKey, + Lb: lb, + ApiKey: apiKeys, } } diff --git a/code/services/gpt3_test.go b/code/services/gpt3_test.go index 4133e636..f9cb20b1 100644 --- a/code/services/gpt3_test.go +++ b/code/services/gpt3_test.go @@ -14,8 +14,9 @@ func TestCompletions(t *testing.T) { {Role: "user", Content: "翻译这段话: The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior."}, } - chatGpt := &ChatGPT{ApiKey: config.OpenaiApiKey} - resp, err := chatGpt.Completions(msgs) + gpt := NewChatGPT(config.OpenaiApiKeys) + + resp, err := gpt.Completions(msgs) if err != nil { t.Errorf("TestCompletions failed with error: %v", err) } @@ -26,7 +27,7 @@ func TestCompletions(t *testing.T) { func TestGenerateOneImage(t *testing.T) { config := initialization.LoadConfig("../config.yaml") - gpt := ChatGPT{ApiKey: config.OpenaiApiKey} + gpt := NewChatGPT(config.OpenaiApiKeys) prompt := "a red apple" size := "256x256" diff --git a/code/services/loadbalancer/loadbalancer.go b/code/services/loadbalancer/loadbalancer.go new file mode 100644 index 00000000..ca460de4 --- /dev/null +++ b/code/services/loadbalancer/loadbalancer.go @@ -0,0 +1,98 @@ +package loadbalancer + +import ( + "sync" +) + +type API struct { + Key string + Times uint32 + Available bool +} + +type LoadBalancer struct { + apis []*API + mu sync.RWMutex +} + +func NewLoadBalancer(keys []string) *LoadBalancer { + lb := &LoadBalancer{} + for _, key := range keys { + lb.apis = append(lb.apis, &API{Key: key}) + } + //SetAvailabilityForAll true + lb.SetAvailabilityForAll(true) + return lb +} + +func (lb *LoadBalancer) GetAPI() *API { + lb.mu.RLock() + defer lb.mu.RUnlock() + + var availableAPIs []*API + for _, api := range lb.apis { + if api.Available { + availableAPIs = append(availableAPIs, api) + } + } + if len(availableAPIs) == 0 { + return nil + } + + selectedAPI := availableAPIs[0] + minTimes := selectedAPI.Times + for _, api := range availableAPIs { + if api.Times < minTimes { + selectedAPI = api + minTimes = api.Times + } + } + selectedAPI.Times++ + //fmt.Printf("API Availability:\n") + //for _, api := range lb.apis { + // fmt.Printf("%s: %v\n", api.Key, api.Available) + // fmt.Printf("%s: %d\n", api.Key, api.Times) + //} + + return selectedAPI +} +func (lb *LoadBalancer) SetAvailability(key string, available bool) { + lb.mu.Lock() + defer lb.mu.Unlock() + + for _, api := range lb.apis { + if api.Key == key { + api.Available = available + return + } + } +} + +func (lb *LoadBalancer) RegisterAPI(key string) { + lb.mu.Lock() + defer lb.mu.Unlock() + + if lb.apis == nil { + lb.apis = make([]*API, 0) + } + + lb.apis = append(lb.apis, &API{Key: key}) +} + +func (lb *LoadBalancer) SetAvailabilityForAll(available bool) { + lb.mu.Lock() + defer lb.mu.Unlock() + + for _, api := range lb.apis { + api.Available = available + } +} + +func (lb *LoadBalancer) GetAPIs() []*API { + lb.mu.RLock() + defer lb.mu.RUnlock() + + apis := make([]*API, len(lb.apis)) + copy(apis, lb.apis) + return apis +} diff --git a/entrypoint.sh b/entrypoint.sh index d969ef7b..cf79e69e 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,4 +1,8 @@ #!/bin/bash + + +#用来从环境变量中获取配置信息,将其写入到配置文件 +#默认值已在config层完成,这里不再重复 set -e APP_ID=${APP_ID:-""} @@ -7,9 +11,15 @@ APP_ENCRYPT_KEY=${APP_ENCRYPT_KEY:-""} APP_VERIFICATION_TOKEN=${APP_VERIFICATION_TOKEN:-""} BOT_NAME=${BOT_NAME:-""} OPENAI_KEY=${OPENAI_KEY:-""} +HTTP_PORT=${HTTP_PORT:-""} +HTTPS_PORT=${HTTPS_PORT:-""} +USE_HTTPS=${USE_HTTPS:-""} +CERT_FILE=${CERT_FILE:-""} +KEY_FILE=${KEY_FILE:-""} CONFIG_PATH=${CONFIG_PATH:-"config.yaml"} + # modify content in config.yaml if [ "$APP_ID" != "" ] ; then sed -i "2c APP_ID: $APP_ID" $CONFIG_PATH @@ -41,11 +51,34 @@ else echo -e "\033[31m[Warning] You need to set BOT_NAME before running!\033[0m" fi - if [ "$OPENAI_KEY" != "" ] ; then sed -i "9c OPENAI_KEY: $OPENAI_KEY" $CONFIG_PATH else echo -e "\033[31m[Warning] You need to set OPENAI_KEY before running!\033[0m" fi + +# 以下为可选配置 +if [ "$HTTP_PORT" != "" ] ; then +sed -i "11c HTTP_PORT: $HTTP_PORT" $CONFIG_PATH +fi + +if [ "$HTTPS_PORT" != "" ] ; then +sed -i "12c HTTPS_PORT: $HTTPS_PORT" $CONFIG_PATH +fi + +if [ "$USE_HTTPS" != "" ] ; then +sed -i "13c USE_HTTPS: $USE_HTTPS" $CONFIG_PATH +fi + +if [ "$CERT_FILE" != "" ] ; then +sed -i "14c CERT_FILE: $CERT_FILE" $CONFIG_PATH +fi + +if [ "$KEY_FILE" != "" ] ; then +sed -i "15c KEY_FILE: $KEY_FILE" $CONFIG_PATH +fi + +echo -e "\033[32m[Success] Configuration file has been generated!\033[0m" + /dist/feishu_chatgpt diff --git a/readme.md b/readme.md index 694f82c6..c35086d9 100644 --- a/readme.md +++ b/readme.md @@ -206,7 +206,7 @@ docker run -d --name feishu-chatgpt -p 9000:9000 \ --env APP_ENCRYPT_KEY=xxx \ --env APP_VERIFICATION_TOKEN=xxx \ --env BOT_NAME=chatGpt \ ---env OPENAI_KEY=sk-xxx \ +--env OPENAI_KEY="sk-xxx1,sk-xxx2,sk-xxx3" \ feishu-chatgpt:latest ``` @@ -223,7 +223,7 @@ docker run -d --restart=always --name feishu-chatgpt2 -p 9000:9000 -v /etc/local --env APP_ENCRYPT_KEY=xxx \ --env APP_VERIFICATION_TOKEN=xxx \ --env BOT_NAME=chatGpt \ ---env OPENAI_KEY=sk-xxx \ +--env OPENAI_KEY="sk-xxx1,sk-xxx2,sk-xxx3" \ dockerproxy.com/leizhenpeng/feishu-chatgpt:latest ```