From d5e90be3db3ed2d8e256c6adc5d8ded4049535b7 Mon Sep 17 00:00:00 2001 From: Max Date: Sat, 16 Dec 2023 18:12:56 -0700 Subject: [PATCH 01/61] #29: openai client --- pkg/providers/openai/openai.go | 88 ++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 pkg/providers/openai/openai.go diff --git a/pkg/providers/openai/openai.go b/pkg/providers/openai/openai.go new file mode 100644 index 00000000..6e71e845 --- /dev/null +++ b/pkg/providers/openai/openai.go @@ -0,0 +1,88 @@ +package openai + +import ( + "fmt" + "net/http" + "io" + "bytes" +) + +type OpenAiClient struct { + apiKey string + baseURL string + http *http.Client +} + +func NewOpenAiClient(apiKey string) *OpenAiClient { + return &OpenAiClient{ + apiKey: apiKey, + baseURL: "https://api.openai.com/v1", + http: http.DefaultClient, + } +} + +func (c *OpenAiClient) SetBaseURL(baseURL string) { + c.baseURL = baseURL +} + +func (c *OpenAiClient) SetHTTPOpenAiClient(httpOpenAiClient *http.Client) { + c.http = httpOpenAiClient +} + +func (c *OpenAiClient) GetAPIKey() string { + return c.apiKey +} + +func (c *OpenAiClient) Get(endpoint string) (string, error) { + // Implement the logic to make a GET request to the OpenAI API + + return "", nil +} + +func (c *OpenAiClient) Post(endpoint string, payload []byte) (string, error) { + // Implement the logic to make a POST request to the OpenAI API + + // Create the full URL + url := c.baseURL + endpoint + + // Create a new request using http + req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) + if err != nil { + return "", err + } + + // Set the headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.apiKey) + + // Send the request using http Client + resp, err := c.http.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + return string(responseBody), nil +} + +// Add more methods to interact with OpenAI API + +func main() { + // Example usage of the OpenAI OpenAiClient + OpenAiClient := NewOpenAiClient("YOUR_API_KEY") + + // Call methods on the OpenAiClient to interact with the OpenAI API + // For example: + response, err := OpenAiClient.Get("/endpoints") + if err != nil { + fmt.Println("Error:", err) + return + } + + fmt.Println("Response:", response) +} \ No newline at end of file From 4282f551d076e609575292198ad12cc7a3ace092 Mon Sep 17 00:00:00 2001 From: Max Date: Sat, 16 Dec 2023 18:28:41 -0700 Subject: [PATCH 02/61] #29: add sample --- pkg/providers/openai/openai.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/providers/openai/openai.go b/pkg/providers/openai/openai.go index 6e71e845..6f0add9f 100644 --- a/pkg/providers/openai/openai.go +++ b/pkg/providers/openai/openai.go @@ -78,7 +78,8 @@ func main() { // Call methods on the OpenAiClient to interact with the OpenAI API // For example: - response, err := OpenAiClient.Get("/endpoints") + payrload := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}`) + response, err := OpenAiClient.Post("/chat", payrload) if err != nil { fmt.Println("Error:", err) return From 33c89308456d6c9b4261a86e0c2a9e950c5303d8 Mon Sep 17 00:00:00 2001 From: Max Date: Sat, 16 Dec 2023 18:36:42 -0700 Subject: [PATCH 03/61] #29: Refactor OpenAI provider configuration --- pkg/providers/openai/chat.go | 53 ---------------------------------- pkg/providers/openai/openai.go | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 53 deletions(-) delete mode 100644 pkg/providers/openai/chat.go diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go deleted file mode 100644 index 064e7d46..00000000 --- a/pkg/providers/openai/chat.go +++ /dev/null @@ -1,53 +0,0 @@ -package openai - -type OpenAiProviderConfig struct { - Model string `json:"model" validate:"required,lowercase"` - Messages string `json:"messages" validate:"required"` // does this need to be updated to []string? - MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` - Temperature int `json:"temperature" validate:"omitempty,gte=0,lte=2"` - TopP int `json:"top_p" validate:"omitempty,gte=0,lte=1"` - N int `json:"n" validate:"omitempty,gte=1"` - Stream bool `json:"stream" validate:"omitempty, boolean"` - Stop interface{} `json:"stop"` - PresencePenalty int `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"` - FrequencyPenalty int `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"` - LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"` - User interface{} `json:"user"` - Seed interface{} `json:"seed" validate:"omitempty,gte=0"` - Tools []string `json:"tools"` - ToolChoice interface{} `json:"tool_choice"` - ResponseFormat interface{} `json:"response_format"` -} - -var defaultMessage = `[ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello!" - } - ]` - -// Provide the request body for OpenAI's ChatCompletion API -func OpenAiChatDefaultConfig() OpenAiProviderConfig { - return OpenAiProviderConfig{ - Model: "gpt-3.5-turbo", - Messages: defaultMessage, - MaxTokens: 100, - Temperature: 1, - TopP: 1, - N: 1, - Stream: false, - Stop: nil, - PresencePenalty: 0, - FrequencyPenalty: 0, - LogitBias: nil, - User: nil, - Seed: nil, - Tools: nil, - ToolChoice: nil, - ResponseFormat: nil, - } -} diff --git a/pkg/providers/openai/openai.go b/pkg/providers/openai/openai.go index 6f0add9f..43cc46dc 100644 --- a/pkg/providers/openai/openai.go +++ b/pkg/providers/openai/openai.go @@ -7,9 +7,29 @@ import ( "bytes" ) +type OpenAiProviderConfig struct { + Model string `json:"model" validate:"required,lowercase"` + Messages string `json:"messages" validate:"required"` // does this need to be updated to []string? + MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` + Temperature int `json:"temperature" validate:"omitempty,gte=0,lte=2"` + TopP int `json:"top_p" validate:"omitempty,gte=0,lte=1"` + N int `json:"n" validate:"omitempty,gte=1"` + Stream bool `json:"stream" validate:"omitempty, boolean"` + Stop interface{} `json:"stop"` + PresencePenalty int `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"` + FrequencyPenalty int `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"` + LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"` + User interface{} `json:"user"` + Seed interface{} `json:"seed" validate:"omitempty,gte=0"` + Tools []string `json:"tools"` + ToolChoice interface{} `json:"tool_choice"` + ResponseFormat interface{} `json:"response_format"` +} + type OpenAiClient struct { apiKey string baseURL string + params OpenAiProviderConfig http *http.Client } @@ -17,10 +37,42 @@ func NewOpenAiClient(apiKey string) *OpenAiClient { return &OpenAiClient{ apiKey: apiKey, baseURL: "https://api.openai.com/v1", + params: OpenAiChatDefaultConfig(), http: http.DefaultClient, } } +var defaultMessage = `[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello!" + } + ]` + +func OpenAiChatDefaultConfig() OpenAiProviderConfig { + return OpenAiProviderConfig{ + Model: "gpt-3.5-turbo", + Messages: defaultMessage, + MaxTokens: 100, + Temperature: 1, + TopP: 1, + N: 1, + Stream: false, + Stop: nil, + PresencePenalty: 0, + FrequencyPenalty: 0, + LogitBias: nil, + User: nil, + Seed: nil, + Tools: nil, + ToolChoice: nil, + ResponseFormat: nil, + } +} func (c *OpenAiClient) SetBaseURL(baseURL string) { c.baseURL = baseURL } From fbbeb7ed343d631029e5383130f46bd2303846f8 Mon Sep 17 00:00:00 2001 From: Max Date: Sat, 16 Dec 2023 19:49:19 -0700 Subject: [PATCH 04/61] #29: Refactor OpenAiClient struct and methods --- pkg/providers/openai/openai.go | 78 ++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/pkg/providers/openai/openai.go b/pkg/providers/openai/openai.go index 43cc46dc..e83f4531 100644 --- a/pkg/providers/openai/openai.go +++ b/pkg/providers/openai/openai.go @@ -1,10 +1,17 @@ package openai import ( + "bytes" "fmt" - "net/http" "io" - "bytes" + "log/slog" + "encoding/json" + "context" + + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" + ) type OpenAiProviderConfig struct { @@ -27,18 +34,18 @@ type OpenAiProviderConfig struct { } type OpenAiClient struct { - apiKey string - baseURL string - params OpenAiProviderConfig - http *http.Client + apiKey string + baseURL string + params OpenAiProviderConfig + http *client.Client } func NewOpenAiClient(apiKey string) *OpenAiClient { return &OpenAiClient{ - apiKey: apiKey, - baseURL: "https://api.openai.com/v1", - params: OpenAiChatDefaultConfig(), - http: http.DefaultClient, + apiKey: apiKey, + baseURL: "https://api.openai.com/v1", + params: OpenAiChatDefaultConfig(), + http: HertzClient(), } } @@ -52,7 +59,17 @@ var defaultMessage = `[ "content": "Hello!" } ]` - + +func HertzClient() *client.Client { + + c, err := client.NewClient() + if err != nil { + slog.Error(err.Error()) + } + return c + +} + func OpenAiChatDefaultConfig() OpenAiProviderConfig { return OpenAiProviderConfig{ Model: "gpt-3.5-turbo", @@ -77,7 +94,7 @@ func (c *OpenAiClient) SetBaseURL(baseURL string) { c.baseURL = baseURL } -func (c *OpenAiClient) SetHTTPOpenAiClient(httpOpenAiClient *http.Client) { +func (c *OpenAiClient) SetHTTPOpenAiClient(httpOpenAiClient *client.Client) { c.http = httpOpenAiClient } @@ -94,32 +111,19 @@ func (c *OpenAiClient) Get(endpoint string) (string, error) { func (c *OpenAiClient) Post(endpoint string, payload []byte) (string, error) { // Implement the logic to make a POST request to the OpenAI API + req := &protocol.Request{} + res := &protocol.Response{} + // Create the full URL url := c.baseURL + endpoint - // Create a new request using http - req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) + req.Header.SetMethod(consts.MethodPost) + req.Header.SetContentTypeBytes([]byte("application/json")) + req.SetRequestURI(url) + req.SetBody(payload) + err = client.Do(context.Background(), req, res) if err != nil { - return "", err - } - - // Set the headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.apiKey) - - // Send the request using http Client - resp, err := c.http.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - return string(responseBody), nil + return } // Add more methods to interact with OpenAI API @@ -127,7 +131,7 @@ func (c *OpenAiClient) Post(endpoint string, payload []byte) (string, error) { func main() { // Example usage of the OpenAI OpenAiClient OpenAiClient := NewOpenAiClient("YOUR_API_KEY") - + // Call methods on the OpenAiClient to interact with the OpenAI API // For example: payrload := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}`) @@ -136,6 +140,6 @@ func main() { fmt.Println("Error:", err) return } - + fmt.Println("Response:", response) -} \ No newline at end of file +} From 62cd933026f66e19bc79d387901f9bbfe69b3923 Mon Sep 17 00:00:00 2001 From: Max Date: Sat, 16 Dec 2023 19:59:21 -0700 Subject: [PATCH 05/61] #29: Refactor OpenAiClient struct and methods --- pkg/providers/openai/openai.go | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/pkg/providers/openai/openai.go b/pkg/providers/openai/openai.go index e83f4531..1bbbc148 100644 --- a/pkg/providers/openai/openai.go +++ b/pkg/providers/openai/openai.go @@ -1,11 +1,8 @@ package openai import ( - "bytes" "fmt" - "io" "log/slog" - "encoding/json" "context" "github.com/cloudwego/hertz/pkg/app/client" @@ -102,13 +99,7 @@ func (c *OpenAiClient) GetAPIKey() string { return c.apiKey } -func (c *OpenAiClient) Get(endpoint string) (string, error) { - // Implement the logic to make a GET request to the OpenAI API - - return "", nil -} - -func (c *OpenAiClient) Post(endpoint string, payload []byte) (string, error) { +func (c *OpenAiClient) Post(endpoint string, payload []byte) ([]byte, error) { // Implement the logic to make a POST request to the OpenAI API req := &protocol.Request{} @@ -121,9 +112,14 @@ func (c *OpenAiClient) Post(endpoint string, payload []byte) (string, error) { req.Header.SetContentTypeBytes([]byte("application/json")) req.SetRequestURI(url) req.SetBody(payload) - err = client.Do(context.Background(), req, res) - if err != nil { - return + // Define the err variable + err := client.Do(context.Background(), req, res) + if err != nil { + slog.Error(err.Error()) + // Return nil and the error + return nil, err + } + return res.Body(), nil } // Add more methods to interact with OpenAI API @@ -134,8 +130,8 @@ func main() { // Call methods on the OpenAiClient to interact with the OpenAI API // For example: - payrload := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}`) - response, err := OpenAiClient.Post("/chat", payrload) + payload := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}`) + response, err := OpenAiClient.Post("/chat", payload) if err != nil { fmt.Println("Error:", err) return From da763190751aed031e1d07b689d8f842b175afdc Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 17 Dec 2023 15:06:56 -0700 Subject: [PATCH 06/61] #29: refactor --- pkg/buildAPIRequest.go | 1 + pkg/providers/openai/chat.go | 244 ++++++++++++++++++ pkg/providers/openai/index.go | 12 - .../openai/{openai.go => openai_temp} | 28 +- pkg/providers/openai/openaiclient.go | 165 ++++++++++++ 5 files changed, 426 insertions(+), 24 deletions(-) create mode 100644 pkg/providers/openai/chat.go delete mode 100644 pkg/providers/openai/index.go rename pkg/providers/openai/{openai.go => openai_temp} (87%) create mode 100644 pkg/providers/openai/openaiclient.go diff --git a/pkg/buildAPIRequest.go b/pkg/buildAPIRequest.go index f28d22b1..ff591446 100644 --- a/pkg/buildAPIRequest.go +++ b/pkg/buildAPIRequest.go @@ -8,6 +8,7 @@ import ( "fmt" "glide/pkg/providers" "glide/pkg/providers/openai" + ) type ProviderConfigs = pkg.ProviderConfigs diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go new file mode 100644 index 00000000..961ba72e --- /dev/null +++ b/pkg/providers/openai/chat.go @@ -0,0 +1,244 @@ +package openaiclient + + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "github.com/cloudwego/hertz/pkg/app/client" +) + +const ( + defaultChatModel = "gpt-3.5-turbo" +) + +// ChatRequest is a request to complete a chat completion.. +type ChatRequest struct { + Model string `json:"model"` + Messages []*ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + + // Function definitions to include in the request. + Functions []FunctionDefinition `json:"functions,omitempty"` + // FunctionCallBehavior is the behavior to use when calling functions. + // + // If a specific function should be invoked, use the format: + // `{"name": "my_function"}` + FunctionCallBehavior FunctionCallBehavior `json:"function_call,omitempty"` + + // StreamingFunc is a function to be called for each chunk of a streaming response. + // Return an error to stop streaming early. + StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"` +} + +// ChatMessage is a message in a chat request. +type ChatMessage struct { + // The role of the author of this message. One of system, user, or assistant. + Role string `json:"role"` + // The content of the message. + Content string `json:"content"` + // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, + // with a maximum length of 64 characters. + Name string `json:"name,omitempty"` + + // FunctionCall represents a function call to be made in the message. + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + +// ChatChoice is a choice in a chat response. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// ChatUsage is the usage of a chat completion request. +type ChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// ChatResponse is a response to a chat request. +type ChatResponse struct { + ID string `json:"id,omitempty"` + Created float64 `json:"created,omitempty"` + Choices []*ChatChoice `json:"choices,omitempty"` + Model string `json:"model,omitempty"` + Object string `json:"object,omitempty"` + Usage struct { + CompletionTokens float64 `json:"completion_tokens,omitempty"` + PromptTokens float64 `json:"prompt_tokens,omitempty"` + TotalTokens float64 `json:"total_tokens,omitempty"` + } `json:"usage,omitempty"` +} + +// StreamedChatResponsePayload is a chunk from the stream. +type StreamedChatResponsePayload struct { + ID string `json:"id,omitempty"` + Created float64 `json:"created,omitempty"` + Model string `json:"model,omitempty"` + Object string `json:"object,omitempty"` + Choices []struct { + Index float64 `json:"index,omitempty"` + Delta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + } `json:"delta,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + } `json:"choices,omitempty"` +} + +// FunctionDefinition is a definition of a function that can be called by the model. +type FunctionDefinition struct { + // Name is the name of the function. + Name string `json:"name"` + // Description is a description of the function. + Description string `json:"description"` + // Parameters is a list of parameters for the function. + Parameters any `json:"parameters"` +} + +// FunctionCallBehavior is the behavior to use when calling functions. +type FunctionCallBehavior string + +const ( + // FunctionCallBehaviorUnspecified is the empty string. + FunctionCallBehaviorUnspecified FunctionCallBehavior = "" + // FunctionCallBehaviorNone will not call any functions. + FunctionCallBehaviorNone FunctionCallBehavior = "none" + // FunctionCallBehaviorAuto will call functions automatically. + FunctionCallBehaviorAuto FunctionCallBehavior = "auto" +) + +// FunctionCall is a call to a function. +type FunctionCall struct { + // Name is the name of the function to call. + Name string `json:"name"` + // Arguments is the set of arguments to pass to the function. + Arguments string `json:"arguments"` +} + +func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { + if payload.StreamingFunc != nil { + payload.Stream = true + } + // Build request payload + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + // Build request + body := bytes.NewReader(payloadBytes) + if c.baseURL == "" { + c.baseURL = defaultBaseURL + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions", c.Model), body) + if err != nil { + return nil, err + } + + c.setHeaders(req) + + // Send request + r, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer r.Body.Close() + + if r.StatusCode != http.StatusOK { + msg := fmt.Sprintf("API returned unexpected status code: %d", r.StatusCode) + + // No need to check the error here: if it fails, we'll just return the + // status code. + var errResp errorMessage + if err := json.NewDecoder(r.Body).Decode(&errResp); err != nil { + return nil, errors.New(msg) // nolint:goerr113 + } + + return nil, fmt.Errorf("%s: %s", msg, errResp.Error.Message) // nolint:goerr113 + } + if payload.StreamingFunc != nil { + return parseStreamingChatResponse(ctx, r, payload) + } + // Parse response + var response ChatResponse + return &response, json.NewDecoder(r.Body).Decode(&response) +} + +func parseStreamingChatResponse(ctx context.Context, r *http.Response, payload *ChatRequest) (*ChatResponse, error) { //nolint:cyclop,lll + scanner := bufio.NewScanner(r.Body) + responseChan := make(chan StreamedChatResponsePayload) + go func() { + defer close(responseChan) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + if !strings.HasPrefix(line, "data:") { + log.Fatalf("unexpected line: %v", line) + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + return + } + var streamPayload StreamedChatResponsePayload + err := json.NewDecoder(bytes.NewReader([]byte(data))).Decode(&streamPayload) + if err != nil { + log.Fatalf("failed to decode stream payload: %v", err) + } + responseChan <- streamPayload + } + if err := scanner.Err(); err != nil { + log.Println("issue scanning response:", err) + } + }() + // Parse response + response := ChatResponse{ + Choices: []*ChatChoice{ + {}, + }, + } + + for streamResponse := range responseChan { + if len(streamResponse.Choices) == 0 { + continue + } + chunk := []byte(streamResponse.Choices[0].Delta.Content) + response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content + response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason + if streamResponse.Choices[0].Delta.FunctionCall != nil { + if response.Choices[0].Message.FunctionCall == nil { + response.Choices[0].Message.FunctionCall = streamResponse.Choices[0].Delta.FunctionCall + } else { + response.Choices[0].Message.FunctionCall.Arguments += streamResponse.Choices[0].Delta.FunctionCall.Arguments + } + chunk, _ = json.Marshal(response.Choices[0].Message.FunctionCall) // nolint:errchkjson + } + + if payload.StreamingFunc != nil { + err := payload.StreamingFunc(ctx, chunk) + if err != nil { + return nil, fmt.Errorf("streaming func returned an error: %w", err) + } + } + } + return &response, nil +} \ No newline at end of file diff --git a/pkg/providers/openai/index.go b/pkg/providers/openai/index.go deleted file mode 100644 index d7a7a60c..00000000 --- a/pkg/providers/openai/index.go +++ /dev/null @@ -1,12 +0,0 @@ -package openai - -import ( - "glide/pkg/providers" -) - - -// TODO: this needs to be imported into buildAPIRequest.go -var OpenAIConfig = pkg.ProviderConfigs{ - "api": OpenAIAPIConfig, - "chat": OpenAiChatDefaultConfig, -} diff --git a/pkg/providers/openai/openai.go b/pkg/providers/openai/openai_temp similarity index 87% rename from pkg/providers/openai/openai.go rename to pkg/providers/openai/openai_temp index 1bbbc148..d8dee11b 100644 --- a/pkg/providers/openai/openai.go +++ b/pkg/providers/openai/openai_temp @@ -1,8 +1,11 @@ package openai import ( + "bytes" "fmt" + "io" "log/slog" + "encoding/json" "context" "github.com/cloudwego/hertz/pkg/app/client" @@ -99,7 +102,13 @@ func (c *OpenAiClient) GetAPIKey() string { return c.apiKey } -func (c *OpenAiClient) Post(endpoint string, payload []byte) ([]byte, error) { +func (c *OpenAiClient) Get(endpoint string) (string, error) { + // Implement the logic to make a GET request to the OpenAI API + + return "", nil +} + +func (c *OpenAiClient) Post(endpoint string, payload []byte) (string, error) { // Implement the logic to make a POST request to the OpenAI API req := &protocol.Request{} @@ -112,14 +121,9 @@ func (c *OpenAiClient) Post(endpoint string, payload []byte) ([]byte, error) { req.Header.SetContentTypeBytes([]byte("application/json")) req.SetRequestURI(url) req.SetBody(payload) - // Define the err variable - err := client.Do(context.Background(), req, res) - if err != nil { - slog.Error(err.Error()) - // Return nil and the error - return nil, err - } - return res.Body(), nil + err = client.Do(context.Background(), req, res) + if err != nil { + return } // Add more methods to interact with OpenAI API @@ -130,12 +134,12 @@ func main() { // Call methods on the OpenAiClient to interact with the OpenAI API // For example: - payload := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}`) - response, err := OpenAiClient.Post("/chat", payload) + payrload := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}`) + response, err := OpenAiClient.Post("/chat", payrload) if err != nil { fmt.Println("Error:", err) return } fmt.Println("Response:", response) -} +} \ No newline at end of file diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go new file mode 100644 index 00000000..6de7e9ed --- /dev/null +++ b/pkg/providers/openai/openaiclient.go @@ -0,0 +1,165 @@ +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "log/slog" + + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/protocol" + +) + +const ( + defaultBaseURL = "https://api.openai.com/v1" + defaultFunctionCallBehavior = "auto" +) + +// ErrEmptyResponse is returned when the OpenAI API returns an empty response. +var ErrEmptyResponse = errors.New("empty response") + +type APIType string + +const ( + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" +) + +// Client is a client for the OpenAI API. +type Client struct { + token string + Model string + baseURL string + organization string + apiType APIType + httpClient *client.Client + + // required when APIType is APITypeAzure or APITypeAzureAD + apiVersion string + embeddingsModel string +} + +// Option is an option for the OpenAI client. +type Option func(*Client) error + +// Doer performs a HTTP request. +type Doer interface { + Do(req protocol.Request) (protocol.Response, error) +} + +// New returns a new OpenAI client. +func New(token string, model string, baseURL string, organization string, + apiType APIType, apiVersion string, httpClient Doer, embeddingsModel string, + opts ...Option, +) (*Client, error) { + c := &Client{ + token: token, + Model: model, + embeddingsModel: embeddingsModel, + baseURL: baseURL, + organization: organization, + apiType: apiType, + apiVersion: apiVersion, + httpClient: HertzClient(), + } + + for _, opt := range opts { + if err := opt(c); err != nil { + return nil, err + } + } + + return c, nil +} + +func HertzClient() *client.Client { + + c, err := client.NewClient() + if err != nil { + slog.Error(err.Error()) + } + return c + +} + +// Completion is a completion. +type Completion struct { + Text string `json:"text"` +} + +// CreateCompletion creates a completion. +func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*Completion, error) { + resp, err := c.createCompletion(ctx, r) + if err != nil { + return nil, err + } + if len(resp.Choices) == 0 { + return nil, ErrEmptyResponse + } + return &Completion{ + Text: resp.Choices[0].Message.Content, + }, nil +} + + +// CreateChat creates chat request. +func (c *Client) CreateChat(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { + if r.Model == "" { + if c.Model == "" { + r.Model = defaultChatModel + } else { + r.Model = c.Model + } + } + if r.FunctionCallBehavior == "" && len(r.Functions) > 0 { + r.FunctionCallBehavior = defaultFunctionCallBehavior + } + resp, err := c.createChat(ctx, r) + if err != nil { + return nil, err + } + if len(resp.Choices) == 0 { + return nil, ErrEmptyResponse + } + return resp, nil +} + +func IsAzure(apiType APIType) bool { + return apiType == APITypeAzure || apiType == APITypeAzureAD +} + +func (c *Client) setHeaders(req *http.Request) { + req.Header.Set("Content-Type", "application/json") + if c.apiType == APITypeOpenAI || c.apiType == APITypeAzureAD { + req.Header.Set("Authorization", "Bearer "+c.token) + } else { + req.Header.Set("api-key", c.token) + } + if c.organization != "" { + req.Header.Set("OpenAI-Organization", c.organization) + } +} + +func (c *Client) buildURL(suffix string, model string) string { + if IsAzure(c.apiType) { + return c.buildAzureURL(suffix, model) + } + + // open ai implement: + return fmt.Sprintf("%s%s", c.baseURL, suffix) +} + +func (c *Client) buildAzureURL(suffix string, model string) string { + baseURL := c.baseURL + baseURL = strings.TrimRight(baseURL, "/") + + // azure example url: + // /openai/deployments/{model}/chat/completions?api-version={api_version} + return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", + baseURL, model, suffix, c.apiVersion, + ) +} \ No newline at end of file From e3266ac8173670ab4beb726267f2ab4a371a2eee Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 18 Dec 2023 17:50:21 -0700 Subject: [PATCH 07/61] 29: chat converted to Hertz --- pkg/providers/openai/chat.go | 43 ++++++++----------- .../openai/{openai_temp => openai_temp.go} | 0 2 files changed, 19 insertions(+), 24 deletions(-) rename pkg/providers/openai/{openai_temp => openai_temp.go} (100%) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 961ba72e..c48fe269 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -1,17 +1,18 @@ -package openaiclient - +package openai import ( "bufio" "bytes" "context" "encoding/json" - "errors" "fmt" "log" "net/http" "strings" + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" ) const ( @@ -144,46 +145,40 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes } // Build request - body := bytes.NewReader(payloadBytes) if c.baseURL == "" { c.baseURL = defaultBaseURL } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions", c.Model), body) - if err != nil { - return nil, err - } - c.setHeaders(req) + req := &protocol.Request{} + res := &protocol.Response{} + req.Header.SetMethod(consts.MethodPost) + req.SetRequestURI(c.buildURL("/chat/completions", c.Model)) + req.SetBody(payloadBytes) + // Send request - r, err := c.httpClient.Do(req) + err = client.Do(ctx, req, res) if err != nil { return nil, err } - defer r.Body.Close() - if r.StatusCode != http.StatusOK { - msg := fmt.Sprintf("API returned unexpected status code: %d", r.StatusCode) + defer res.ConnectionClose() // replaced r.Body.Close() - // No need to check the error here: if it fails, we'll just return the - // status code. - var errResp errorMessage - if err := json.NewDecoder(r.Body).Decode(&errResp); err != nil { - return nil, errors.New(msg) // nolint:goerr113 - } + if res.StatusCode() != http.StatusOK { + msg := fmt.Sprintf("API returned unexpected status code: %d", res.StatusCode()) - return nil, fmt.Errorf("%s: %s", msg, errResp.Error.Message) // nolint:goerr113 + return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113 } if payload.StreamingFunc != nil { - return parseStreamingChatResponse(ctx, r, payload) + return parseStreamingChatResponse(ctx, res, payload) } // Parse response var response ChatResponse - return &response, json.NewDecoder(r.Body).Decode(&response) + return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response) } -func parseStreamingChatResponse(ctx context.Context, r *http.Response, payload *ChatRequest) (*ChatResponse, error) { //nolint:cyclop,lll - scanner := bufio.NewScanner(r.Body) +func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, payload *ChatRequest) (*ChatResponse, error) { + scanner := bufio.NewScanner(bytes.NewReader(r.Body())) responseChan := make(chan StreamedChatResponsePayload) go func() { defer close(responseChan) diff --git a/pkg/providers/openai/openai_temp b/pkg/providers/openai/openai_temp.go similarity index 100% rename from pkg/providers/openai/openai_temp rename to pkg/providers/openai/openai_temp.go From f34604b953feba3293017faf7640e810567dcb08 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 18 Dec 2023 18:31:01 -0700 Subject: [PATCH 08/61] #29: Update OpenAI provider configuration and chat request validation --- pkg/providers/openai/chat.go | 30 +++++++++++++--------------- pkg/providers/openai/openai_temp.go | 12 ++--------- pkg/providers/openai/openaiclient.go | 21 ++----------------- 3 files changed, 18 insertions(+), 45 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index c48fe269..897099c6 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -21,24 +21,22 @@ const ( // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { - Model string `json:"model"` - Messages []*ChatMessage `json:"messages"` + Model string `json:"model" validate:"required,lowercase"` + Messages []*ChatMessage `json:"messages" validate:"required"` Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` + TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` + MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` + N int `json:"n,omitempty" validate:"omitempty,gte=1"` StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - - // Function definitions to include in the request. - Functions []FunctionDefinition `json:"functions,omitempty"` - // FunctionCallBehavior is the behavior to use when calling functions. - // - // If a specific function should be invoked, use the format: - // `{"name": "my_function"}` - FunctionCallBehavior FunctionCallBehavior `json:"function_call,omitempty"` + Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty" validate:"omitempty"` + User interface{} `json:"user,omitempty"` + Seed interface{} `json:"seed,omitempty" validate:"omitempty,gte=0"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` // StreamingFunc is a function to be called for each chunk of a streaming response. // Return an error to stop streaming early. diff --git a/pkg/providers/openai/openai_temp.go b/pkg/providers/openai/openai_temp.go index d8dee11b..04a0845a 100644 --- a/pkg/providers/openai/openai_temp.go +++ b/pkg/providers/openai/openai_temp.go @@ -16,9 +16,9 @@ import ( type OpenAiProviderConfig struct { Model string `json:"model" validate:"required,lowercase"` - Messages string `json:"messages" validate:"required"` // does this need to be updated to []string? + Messages []*ChatMessage `json:"messages" validate:"required"` // does this need to be updated to []string? MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` - Temperature int `json:"temperature" validate:"omitempty,gte=0,lte=2"` + Temperature float64 `json:"temperature" validate:"omitempty,gte=0,lte=2"` TopP int `json:"top_p" validate:"omitempty,gte=0,lte=1"` N int `json:"n" validate:"omitempty,gte=1"` Stream bool `json:"stream" validate:"omitempty, boolean"` @@ -60,15 +60,7 @@ var defaultMessage = `[ } ]` -func HertzClient() *client.Client { - c, err := client.NewClient() - if err != nil { - slog.Error(err.Error()) - } - return c - -} func OpenAiChatDefaultConfig() OpenAiProviderConfig { return OpenAiProviderConfig{ diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 6de7e9ed..59c3698b 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net/http" "strings" "log/slog" @@ -91,20 +90,6 @@ type Completion struct { Text string `json:"text"` } -// CreateCompletion creates a completion. -func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*Completion, error) { - resp, err := c.createCompletion(ctx, r) - if err != nil { - return nil, err - } - if len(resp.Choices) == 0 { - return nil, ErrEmptyResponse - } - return &Completion{ - Text: resp.Choices[0].Message.Content, - }, nil -} - // CreateChat creates chat request. func (c *Client) CreateChat(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { @@ -115,9 +100,7 @@ func (c *Client) CreateChat(ctx context.Context, r *ChatRequest) (*ChatResponse, r.Model = c.Model } } - if r.FunctionCallBehavior == "" && len(r.Functions) > 0 { - r.FunctionCallBehavior = defaultFunctionCallBehavior - } + resp, err := c.createChat(ctx, r) if err != nil { return nil, err @@ -132,7 +115,7 @@ func IsAzure(apiType APIType) bool { return apiType == APITypeAzure || apiType == APITypeAzureAD } -func (c *Client) setHeaders(req *http.Request) { +func (c *Client) setHeaders(req *protocol.Request) { req.Header.Set("Content-Type", "application/json") if c.apiType == APITypeOpenAI || c.apiType == APITypeAzureAD { req.Header.Set("Authorization", "Bearer "+c.token) From 131122815a42403d8230933a935e909b742394f1 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 18 Dec 2023 18:45:24 -0700 Subject: [PATCH 09/61] #29: clean up --- go.mod | 9 +- go.sum | 18 ---- pkg/buildAPIRequest.go | 60 ------------ pkg/providers/openai/api.go | 30 ------ pkg/providers/openai/chat.go | 23 +++-- pkg/providers/openai/openai_temp.go | 137 ---------------------------- pkg/providers/types.go | 3 - 7 files changed, 12 insertions(+), 268 deletions(-) delete mode 100644 pkg/buildAPIRequest.go delete mode 100644 pkg/providers/openai/api.go delete mode 100644 pkg/providers/openai/openai_temp.go delete mode 100644 pkg/providers/types.go diff --git a/go.mod b/go.mod index 5b909391..0f7e8590 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21.5 require ( github.com/cloudwego/hertz v0.7.3 - github.com/go-playground/validator/v10 v10.16.0 github.com/spf13/cobra v1.8.0 go.uber.org/multierr v1.11.0 ) @@ -16,25 +15,19 @@ require ( github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/cloudwego/netpoll v0.5.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/henrylee2cn/ameda v1.4.10 // indirect github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect - github.com/leodido/go-urn v1.2.4 // indirect github.com/nyaruka/phonenumbers v1.0.55 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/testify v1.8.2 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/crypto v0.7.0 // indirect - golang.org/x/net v0.8.0 // indirect golang.org/x/sys v0.6.0 // indirect - golang.org/x/text v0.8.0 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index 84649808..f0b1d15f 100644 --- a/go.sum +++ b/go.sum @@ -20,16 +20,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= -github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -48,8 +38,6 @@ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -89,11 +77,7 @@ golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5P golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -101,8 +85,6 @@ golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/buildAPIRequest.go b/pkg/buildAPIRequest.go deleted file mode 100644 index ff591446..00000000 --- a/pkg/buildAPIRequest.go +++ /dev/null @@ -1,60 +0,0 @@ -// this file contains the BuildAPIRequest function which takes in the provider name, params map, and mode and returns the providerConfig map and error -// The providerConfig map can be used to build the API request to the provider -package pkg - -import ( - "errors" - "github.com/go-playground/validator/v10" - "fmt" - "glide/pkg/providers" - "glide/pkg/providers/openai" - -) - -type ProviderConfigs = pkg.ProviderConfigs - -// Initialize configList - -var configList = map[string]interface{}{ - "openai": openai.OpenAIConfig, -} - -// Create a new validator instance -var validate *validator.Validate = validator.New() - - -func BuildAPIRequest(provider string, params map[string]string, mode string) (interface{}, error) { - // provider is the name of the provider, e.g. "openai", params is the map of parameters from the client, - // mode is the mode of the provider, e.g. "chat", configList is the list of provider configurations - - - var providerConfig map[string]interface{} - if config, ok := configList[provider].(ProviderConfigs); ok { - if modeConfig, ok := config[mode].(map[string]interface{}); ok { - providerConfig = modeConfig - } -} - - // If the provider is not supported, return an error - if providerConfig == nil { - return nil, errors.New("unsupported provider") - } - - - // Build the providerConfig map by iterating over the keys in the providerConfig map and checking if the key exists in the params map - - for key := range providerConfig { - if value, exists := params[key]; exists { - providerConfig[key] = value - } - } - - // Validate the providerConfig map using the validator package - err := validate.Struct(providerConfig) - if err != nil { - // Handle validation error - return nil, fmt.Errorf("validation error: %v", err) - } - // If everything is fine, return the providerConfig and nil error - return providerConfig, nil -} diff --git a/pkg/providers/openai/api.go b/pkg/providers/openai/api.go deleted file mode 100644 index 4a7a013b..00000000 --- a/pkg/providers/openai/api.go +++ /dev/null @@ -1,30 +0,0 @@ -package openai - -import ( - "fmt" - "net/http" -) - - -// provides the base URL and headers for the OpenAI API -type ProviderAPIConfig struct { - BaseURL string - Headers func(string) http.Header - Complete string - Chat string - Embed string -} - -func OpenAIAPIConfig(APIKey string) *ProviderAPIConfig { - return &ProviderAPIConfig{ - BaseURL: "https://api.openai.com/v1", - Headers: func(APIKey string) http.Header { - headers := make(http.Header) - headers.Set("Authorization", fmt.Sprintf("Bearer %s", APIKey)) - return headers - }, - Complete: "/completions", - Chat: "/chat/completions", - Embed: "/embeddings", - } -} diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 897099c6..3c482f72 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -21,16 +21,16 @@ const ( // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { - Model string `json:"model" validate:"required,lowercase"` - Messages []*ChatMessage `json:"messages" validate:"required"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` - MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` - N int `json:"n,omitempty" validate:"omitempty,gte=1"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` + Model string `json:"model" validate:"required,lowercase"` + Messages []*ChatMessage `json:"messages" validate:"required"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` + MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` + N int `json:"n,omitempty" validate:"omitempty,gte=1"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` LogitBias *map[int]float64 `json:"logit_bias,omitempty" validate:"omitempty"` User interface{} `json:"user,omitempty"` Seed interface{} `json:"seed,omitempty" validate:"omitempty,gte=0"` @@ -153,7 +153,6 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes req.SetRequestURI(c.buildURL("/chat/completions", c.Model)) req.SetBody(payloadBytes) - // Send request err = client.Do(ctx, req, res) if err != nil { @@ -234,4 +233,4 @@ func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, paylo } } return &response, nil -} \ No newline at end of file +} diff --git a/pkg/providers/openai/openai_temp.go b/pkg/providers/openai/openai_temp.go deleted file mode 100644 index 04a0845a..00000000 --- a/pkg/providers/openai/openai_temp.go +++ /dev/null @@ -1,137 +0,0 @@ -package openai - -import ( - "bytes" - "fmt" - "io" - "log/slog" - "encoding/json" - "context" - - "github.com/cloudwego/hertz/pkg/app/client" - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/protocol/consts" - -) - -type OpenAiProviderConfig struct { - Model string `json:"model" validate:"required,lowercase"` - Messages []*ChatMessage `json:"messages" validate:"required"` // does this need to be updated to []string? - MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` - Temperature float64 `json:"temperature" validate:"omitempty,gte=0,lte=2"` - TopP int `json:"top_p" validate:"omitempty,gte=0,lte=1"` - N int `json:"n" validate:"omitempty,gte=1"` - Stream bool `json:"stream" validate:"omitempty, boolean"` - Stop interface{} `json:"stop"` - PresencePenalty int `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"` - FrequencyPenalty int `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"` - LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"` - User interface{} `json:"user"` - Seed interface{} `json:"seed" validate:"omitempty,gte=0"` - Tools []string `json:"tools"` - ToolChoice interface{} `json:"tool_choice"` - ResponseFormat interface{} `json:"response_format"` -} - -type OpenAiClient struct { - apiKey string - baseURL string - params OpenAiProviderConfig - http *client.Client -} - -func NewOpenAiClient(apiKey string) *OpenAiClient { - return &OpenAiClient{ - apiKey: apiKey, - baseURL: "https://api.openai.com/v1", - params: OpenAiChatDefaultConfig(), - http: HertzClient(), - } -} - -var defaultMessage = `[ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello!" - } - ]` - - - -func OpenAiChatDefaultConfig() OpenAiProviderConfig { - return OpenAiProviderConfig{ - Model: "gpt-3.5-turbo", - Messages: defaultMessage, - MaxTokens: 100, - Temperature: 1, - TopP: 1, - N: 1, - Stream: false, - Stop: nil, - PresencePenalty: 0, - FrequencyPenalty: 0, - LogitBias: nil, - User: nil, - Seed: nil, - Tools: nil, - ToolChoice: nil, - ResponseFormat: nil, - } -} -func (c *OpenAiClient) SetBaseURL(baseURL string) { - c.baseURL = baseURL -} - -func (c *OpenAiClient) SetHTTPOpenAiClient(httpOpenAiClient *client.Client) { - c.http = httpOpenAiClient -} - -func (c *OpenAiClient) GetAPIKey() string { - return c.apiKey -} - -func (c *OpenAiClient) Get(endpoint string) (string, error) { - // Implement the logic to make a GET request to the OpenAI API - - return "", nil -} - -func (c *OpenAiClient) Post(endpoint string, payload []byte) (string, error) { - // Implement the logic to make a POST request to the OpenAI API - - req := &protocol.Request{} - res := &protocol.Response{} - - // Create the full URL - url := c.baseURL + endpoint - - req.Header.SetMethod(consts.MethodPost) - req.Header.SetContentTypeBytes([]byte("application/json")) - req.SetRequestURI(url) - req.SetBody(payload) - err = client.Do(context.Background(), req, res) - if err != nil { - return -} - -// Add more methods to interact with OpenAI API - -func main() { - // Example usage of the OpenAI OpenAiClient - OpenAiClient := NewOpenAiClient("YOUR_API_KEY") - - // Call methods on the OpenAiClient to interact with the OpenAI API - // For example: - payrload := []byte(`{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}`) - response, err := OpenAiClient.Post("/chat", payrload) - if err != nil { - fmt.Println("Error:", err) - return - } - - fmt.Println("Response:", response) -} \ No newline at end of file diff --git a/pkg/providers/types.go b/pkg/providers/types.go deleted file mode 100644 index b27094d5..00000000 --- a/pkg/providers/types.go +++ /dev/null @@ -1,3 +0,0 @@ -package pkg - -type ProviderConfigs map[string]interface{} From 08036df3ff447482f6014fd770636412803587e2 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 18 Dec 2023 18:50:08 -0700 Subject: [PATCH 10/61] #29: Remove unused code and refactor parseStreamingChatResponse function --- pkg/providers/openai/chat.go | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 3c482f72..f126b8fa 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -53,8 +53,6 @@ type ChatMessage struct { // with a maximum length of 64 characters. Name string `json:"name,omitempty"` - // FunctionCall represents a function call to be made in the message. - FunctionCall *FunctionCall `json:"function_call,omitempty"` } // ChatChoice is a choice in a chat response. @@ -202,6 +200,7 @@ func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, paylo log.Println("issue scanning response:", err) } }() + // Parse response response := ChatResponse{ Choices: []*ChatChoice{ @@ -216,14 +215,6 @@ func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, paylo chunk := []byte(streamResponse.Choices[0].Delta.Content) response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason - if streamResponse.Choices[0].Delta.FunctionCall != nil { - if response.Choices[0].Message.FunctionCall == nil { - response.Choices[0].Message.FunctionCall = streamResponse.Choices[0].Delta.FunctionCall - } else { - response.Choices[0].Message.FunctionCall.Arguments += streamResponse.Choices[0].Delta.FunctionCall.Arguments - } - chunk, _ = json.Marshal(response.Choices[0].Message.FunctionCall) // nolint:errchkjson - } if payload.StreamingFunc != nil { err := payload.StreamingFunc(ctx, chunk) @@ -233,4 +224,4 @@ func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, paylo } } return &response, nil -} +} \ No newline at end of file From c91a81d76f80dd586b15050925b6fc846ed88b71 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 07:59:30 -0700 Subject: [PATCH 11/61] #29: Update dependencies in go.mod and go.sum files --- go.mod | 1 + go.sum | 3 ++ pkg/providers/openai/chat.go | 36 ++---------------------- pkg/providers/openai/openaiclient.go | 41 ++++++++++++++++++++-------- pkg/providers/types.go | 19 +++++++++++++ 5 files changed, 56 insertions(+), 44 deletions(-) create mode 100644 pkg/providers/types.go diff --git a/go.mod b/go.mod index 0f7e8590..a1d1abd3 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/cloudwego/hertz v0.7.3 github.com/spf13/cobra v1.8.0 go.uber.org/multierr v1.11.0 + gopkg.in/yaml.v2 v2.4.0 ) require ( diff --git a/go.sum b/go.sum index f0b1d15f..92f974c8 100644 --- a/go.sum +++ b/go.sum @@ -91,8 +91,11 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index f126b8fa..27da6706 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -10,7 +10,6 @@ import ( "net/http" "strings" - "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) @@ -94,42 +93,11 @@ type StreamedChatResponsePayload struct { Delta struct { Role string `json:"role,omitempty"` Content string `json:"content,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` } `json:"delta,omitempty"` FinishReason string `json:"finish_reason,omitempty"` } `json:"choices,omitempty"` } -// FunctionDefinition is a definition of a function that can be called by the model. -type FunctionDefinition struct { - // Name is the name of the function. - Name string `json:"name"` - // Description is a description of the function. - Description string `json:"description"` - // Parameters is a list of parameters for the function. - Parameters any `json:"parameters"` -} - -// FunctionCallBehavior is the behavior to use when calling functions. -type FunctionCallBehavior string - -const ( - // FunctionCallBehaviorUnspecified is the empty string. - FunctionCallBehaviorUnspecified FunctionCallBehavior = "" - // FunctionCallBehaviorNone will not call any functions. - FunctionCallBehaviorNone FunctionCallBehavior = "none" - // FunctionCallBehaviorAuto will call functions automatically. - FunctionCallBehaviorAuto FunctionCallBehavior = "auto" -) - -// FunctionCall is a call to a function. -type FunctionCall struct { - // Name is the name of the function to call. - Name string `json:"name"` - // Arguments is the set of arguments to pass to the function. - Arguments string `json:"arguments"` -} - func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { if payload.StreamingFunc != nil { payload.Stream = true @@ -151,8 +119,10 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes req.SetRequestURI(c.buildURL("/chat/completions", c.Model)) req.SetBody(payloadBytes) + c.setHeaders(req) // sets additional headers + // Send request - err = client.Do(ctx, req, res) + err = c.httpClient.Do(ctx, req, res) //*client.Client if err != nil { return nil, err } diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 59c3698b..9360363f 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -6,6 +6,10 @@ import ( "fmt" "strings" "log/slog" + "gopkg.in/yaml.v2" + "os" + + "Glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/protocol" @@ -30,7 +34,7 @@ const ( // Client is a client for the OpenAI API. type Client struct { - token string + apiKey string Model string baseURL string organization string @@ -45,25 +49,40 @@ type Client struct { // Option is an option for the OpenAI client. type Option func(*Client) error -// Doer performs a HTTP request. -type Doer interface { - Do(req protocol.Request) (protocol.Response, error) + +func Init(c *Client) (*client.Client, error) { + // initializes the client + + // Read the YAML file + data, err := os.ReadFile("path/to/file.yaml") + if err != nil { + slog.Error("Failed to read file: %v", err) + } + + // Unmarshal the YAML data into your struct + var config GatewayConfig + err = yaml.Unmarshal(data, &config) + if err != nil { + slog.Error("Failed to unmarshal YAML: %v", err) + } + } + // New returns a new OpenAI client. -func New(token string, model string, baseURL string, organization string, - apiType APIType, apiVersion string, httpClient Doer, embeddingsModel string, +func New(apiKey string, model string, baseURL string, organization string, + apiType APIType, apiVersion string, httpClient *client.Client, embeddingsModel string, opts ...Option, ) (*Client, error) { c := &Client{ - token: token, + apiKey: apiKey, Model: model, embeddingsModel: embeddingsModel, baseURL: baseURL, organization: organization, apiType: apiType, apiVersion: apiVersion, - httpClient: HertzClient(), + httpClient: HttpClient(), } for _, opt := range opts { @@ -75,7 +94,7 @@ func New(token string, model string, baseURL string, organization string, return c, nil } -func HertzClient() *client.Client { +func HttpClient() *client.Client { c, err := client.NewClient() if err != nil { @@ -118,9 +137,9 @@ func IsAzure(apiType APIType) bool { func (c *Client) setHeaders(req *protocol.Request) { req.Header.Set("Content-Type", "application/json") if c.apiType == APITypeOpenAI || c.apiType == APITypeAzureAD { - req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Authorization", "Bearer "+c.apiKey) } else { - req.Header.Set("api-key", c.token) + req.Header.Set("api-key", c.apiKey) } if c.organization != "" { req.Header.Set("OpenAI-Organization", c.organization) diff --git a/pkg/providers/types.go b/pkg/providers/types.go new file mode 100644 index 00000000..456b95e3 --- /dev/null +++ b/pkg/providers/types.go @@ -0,0 +1,19 @@ +package providers +type GatewayConfig struct { + Pools []Pool `yaml:"pools"` +} + +type Pool struct { + Name string `yaml:"name"` + Balancing string `yaml:"balancing"` + Providers []Provider `yaml:"providers"` +} + +type Provider struct { + Name string `yaml:"name"` + Provider string `yaml:"provider"` + Model string `yaml:"model"` + APIKey string `yaml:"api_key"` + TimeoutMs int `yaml:"timeout_ms,omitempty"` + DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` +} \ No newline at end of file From dd4985e6aeaf57f60d7bb55c024a4a959b0adf76 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 08:04:36 -0700 Subject: [PATCH 12/61] #29: build client init --- pkg/providers/openai/openaiclient.go | 21 ++++++++++++++++++++- pkg/providers/types.go | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 9360363f..a69bae56 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -9,7 +9,7 @@ import ( "gopkg.in/yaml.v2" "os" - "Glide/pkg/providers" + //"Glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/protocol" @@ -32,6 +32,25 @@ const ( APITypeAzureAD APIType = "AZURE_AD" ) +type GatewayConfig struct { + Pools []Pool `yaml:"pools"` +} + +type Pool struct { + Name string `yaml:"name"` + Balancing string `yaml:"balancing"` + Providers []Provider `yaml:"providers"` +} + +type Provider struct { + Name string `yaml:"name"` + Provider string `yaml:"provider"` + Model string `yaml:"model"` + APIKey string `yaml:"api_key"` + TimeoutMs int `yaml:"timeout_ms,omitempty"` + DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` +} + // Client is a client for the OpenAI API. type Client struct { apiKey string diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 456b95e3..0143294c 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,4 +1,4 @@ -package providers +package provider type GatewayConfig struct { Pools []Pool `yaml:"pools"` } From 612842301efa718375ac2f9e6ead7719db39df56 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 15:31:14 -0700 Subject: [PATCH 13/61] #29: CreatChatRequest --- pkg/providers/openai/openaiclient.go | 174 ++++++++++++++++++++------- pkg/providers/types.go | 8 ++ 2 files changed, 141 insertions(+), 41 deletions(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index a69bae56..2f741eeb 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -8,6 +8,8 @@ import ( "log/slog" "gopkg.in/yaml.v2" "os" + "json" + "reflect" //"Glide/pkg/providers" @@ -18,12 +20,21 @@ import ( const ( defaultBaseURL = "https://api.openai.com/v1" + defaultOrganization = "" defaultFunctionCallBehavior = "auto" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. -var ErrEmptyResponse = errors.New("empty response") - +var ( + ErrEmptyResponse = errors.New("empty response") + requestBody struct { + Message []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + MessageHistory []string `json:"messageHistory"` + } +) type APIType string const ( @@ -46,15 +57,16 @@ type Provider struct { Name string `yaml:"name"` Provider string `yaml:"provider"` Model string `yaml:"model"` - APIKey string `yaml:"api_key"` + ApiKey string `yaml:"api_key"` TimeoutMs int `yaml:"timeout_ms,omitempty"` DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` } + + // Client is a client for the OpenAI API. type Client struct { - apiKey string - Model string + Provider Provider baseURL string organization string apiType APIType @@ -69,11 +81,11 @@ type Client struct { type Option func(*Client) error -func Init(c *Client) (*client.Client, error) { - // initializes the client +func (c *Client) Init(poolName string, modelName string, providerName string) (*Client, error) { + // Returns a []*Client of OpenAI // Read the YAML file - data, err := os.ReadFile("path/to/file.yaml") + data, err := os.ReadFile("config.yaml") if err != nil { slog.Error("Failed to read file: %v", err) } @@ -85,34 +97,47 @@ func Init(c *Client) (*client.Client, error) { slog.Error("Failed to unmarshal YAML: %v", err) } -} + // Find the pool with the specified name + var selectedPool *Pool + for _, pool := range config.Pools { + if pool.Name == poolName { + selectedPool = &pool + break + } + } + // Check if the pool was found + if selectedPool == nil { + slog.Error("pool '%s' not found", poolName) + } -// New returns a new OpenAI client. -func New(apiKey string, model string, baseURL string, organization string, - apiType APIType, apiVersion string, httpClient *client.Client, embeddingsModel string, - opts ...Option, -) (*Client, error) { - c := &Client{ - apiKey: apiKey, - Model: model, - embeddingsModel: embeddingsModel, - baseURL: baseURL, - organization: organization, - apiType: apiType, - apiVersion: apiVersion, - httpClient: HttpClient(), - } - - for _, opt := range opts { - if err := opt(c); err != nil { - return nil, err + // Find the OpenAI provider in the selected pool with the specified model + var selectedProvider *Provider + for _, provider := range selectedPool.Providers { + if provider.Name == providerName && provider.Model == modelName { + selectedProvider = &provider + break + } + } + + // Check if the provider was found + if selectedProvider == nil { + slog.Error("provider for model '%s' not found in pool '%s'", modelName, poolName) + } + + // Create clients for each OpenAI provider + client := &Client{ + Provider: *selectedProvider, + organization: defaultOrganization, + apiType: APITypeOpenAI, + httpClient: HttpClient(), } - } - return c, nil + return client, nil + } + func HttpClient() *client.Client { c, err := client.NewClient() @@ -123,19 +148,74 @@ func HttpClient() *client.Client { } -// Completion is a completion. -type Completion struct { - Text string `json:"text"` -} +func (c *Client) CreateChatRequest(message []byte) *ChatRequest { + + + err := json.Unmarshal(message, &requestBody) + if err != nil { + slog.Error("Error:", err) + return nil + } + + var messages []*ChatMessage + for _, msg := range requestBody.Message { + chatMsg := &ChatMessage{ + Role: msg.Role, + Content: msg.Content, + } + if msg.Role == "user" { + chatMsg.Content += " " + strings.Join(requestBody.MessageHistory, " ") + } + messages = append(messages, chatMsg) + } + + // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value + + chatRequest := &ChatRequest{ + Model: c.Provider.Model, + Messages: messages, + Temperature: 0.8, + TopP: 1, + MaxTokens: 100, + N: 1, + StopWords: []string{}, + Stream: false, + FrequencyPenalty: 0, + PresencePenalty: 0, + LogitBias: nil, + User: nil, + Seed: nil, + Tools: []string{}, + ToolChoice: nil, + ResponseFormat: nil, + } + + // Use reflection to dynamically assign default parameter values + defaultParams := c.Provider.DefaultParams + v := reflect.ValueOf(chatRequest).Elem() + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + fieldName := field.Name + defaultValue, ok := defaultParams[fieldName] + if ok && defaultValue != nil { + fieldValue := v.FieldByName(fieldName) + if fieldValue.IsValid() && fieldValue.CanSet() { + fieldValue.Set(reflect.ValueOf(defaultValue)) + } + } + } + return chatRequest +} -// CreateChat creates chat request. -func (c *Client) CreateChat(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { +// CreateChatResponse creates chat Response. +func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { if r.Model == "" { - if c.Model == "" { + if c.Provider.Model == "" { r.Model = defaultChatModel } else { - r.Model = c.Model + r.Model = c.Provider.Model } } @@ -156,9 +236,9 @@ func IsAzure(apiType APIType) bool { func (c *Client) setHeaders(req *protocol.Request) { req.Header.Set("Content-Type", "application/json") if c.apiType == APITypeOpenAI || c.apiType == APITypeAzureAD { - req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) } else { - req.Header.Set("api-key", c.apiKey) + req.Header.Set("api-key", c.Provider.ApiKey) } if c.organization != "" { req.Header.Set("OpenAI-Organization", c.organization) @@ -183,4 +263,16 @@ func (c *Client) buildAzureURL(suffix string, model string) string { return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", baseURL, model, suffix, c.apiVersion, ) -} \ No newline at end of file +} + +func main() { + + c := &Client{} + + c, err := c.Init("pool1", "gpt-3.5-turbo", "openai") + if err != nil { + // Handle the error + } + + +} \ No newline at end of file diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 0143294c..0f3a1262 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -16,4 +16,12 @@ type Provider struct { APIKey string `yaml:"api_key"` TimeoutMs int `yaml:"timeout_ms,omitempty"` DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` +} + +type RequestBody struct { + Message []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + MessageHistory []string `json:"messageHistory"` } \ No newline at end of file From f791d7a91c849c29af977b4aad271288703157cf Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 15:43:14 -0700 Subject: [PATCH 14/61] #29: Fix import and log package names --- pkg/providers/openai/chat.go | 96 +++++++++++++++++++++++-- pkg/providers/openai/openaiclient.go | 102 +-------------------------- 2 files changed, 93 insertions(+), 105 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 27da6706..022f1125 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -6,9 +6,10 @@ import ( "context" "encoding/json" "fmt" - "log" - "net/http" + "log/slog" + "reflect" "strings" + "net/http" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -68,6 +69,68 @@ type ChatUsage struct { TotalTokens int `json:"total_tokens"` } +func (c *Client) CreateChatRequest(message []byte) *ChatRequest { + + + err := json.Unmarshal(message, &requestBody) + if err != nil { + slog.Error("Error:", err) + return nil + } + + var messages []*ChatMessage + for _, msg := range requestBody.Message { + chatMsg := &ChatMessage{ + Role: msg.Role, + Content: msg.Content, + } + if msg.Role == "user" { + chatMsg.Content += " " + strings.Join(requestBody.MessageHistory, " ") + } + messages = append(messages, chatMsg) + } + + // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value + + chatRequest := &ChatRequest{ + Model: c.Provider.Model, + Messages: messages, + Temperature: 0.8, + TopP: 1, + MaxTokens: 100, + N: 1, + StopWords: []string{}, + Stream: false, + FrequencyPenalty: 0, + PresencePenalty: 0, + LogitBias: nil, + User: nil, + Seed: nil, + Tools: []string{}, + ToolChoice: nil, + ResponseFormat: nil, + } + + // Use reflection to dynamically assign default parameter values + defaultParams := c.Provider.DefaultParams + v := reflect.ValueOf(chatRequest).Elem() + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + fieldName := field.Name + defaultValue, ok := defaultParams[fieldName] + if ok && defaultValue != nil { + fieldValue := v.FieldByName(fieldName) + if fieldValue.IsValid() && fieldValue.CanSet() { + fieldValue.Set(reflect.ValueOf(defaultValue)) + } + } + } + + return chatRequest +} + + // ChatResponse is a response to a chat request. type ChatResponse struct { ID string `json:"id,omitempty"` @@ -98,6 +161,27 @@ type StreamedChatResponsePayload struct { } `json:"choices,omitempty"` } +// CreateChatResponse creates chat Response. +func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { + if r.Model == "" { + if c.Provider.Model == "" { + r.Model = defaultChatModel + } else { + r.Model = c.Provider.Model + } + } + + resp, err := c.createChat(ctx, r) + if err != nil { + return nil, err + } + if len(resp.Choices) == 0 { + return nil, ErrEmptyResponse + } + return resp, nil +} + + func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { if payload.StreamingFunc != nil { payload.Stream = true @@ -116,7 +200,7 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes req := &protocol.Request{} res := &protocol.Response{} req.Header.SetMethod(consts.MethodPost) - req.SetRequestURI(c.buildURL("/chat/completions", c.Model)) + req.SetRequestURI(c.buildURL("/chat/completions", c.Provider.Model)) req.SetBody(payloadBytes) c.setHeaders(req) // sets additional headers @@ -153,7 +237,7 @@ func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, paylo continue } if !strings.HasPrefix(line, "data:") { - log.Fatalf("unexpected line: %v", line) + slog.Warn("unexpected line:" + line) } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { @@ -162,12 +246,12 @@ func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, paylo var streamPayload StreamedChatResponsePayload err := json.NewDecoder(bytes.NewReader([]byte(data))).Decode(&streamPayload) if err != nil { - log.Fatalf("failed to decode stream payload: %v", err) + slog.Error("failed to decode stream payload: %v", err) } responseChan <- streamPayload } if err := scanner.Err(); err != nil { - log.Println("issue scanning response:", err) + slog.Error("issue scanning response:", err) } }() diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 2f741eeb..52a6a646 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -1,15 +1,12 @@ package openai import ( - "context" "errors" "fmt" "strings" "log/slog" "gopkg.in/yaml.v2" "os" - "json" - "reflect" //"Glide/pkg/providers" @@ -74,7 +71,6 @@ type Client struct { // required when APIType is APITypeAzure or APITypeAzureAD apiVersion string - embeddingsModel string } // Option is an option for the OpenAI client. @@ -108,7 +104,8 @@ func (c *Client) Init(poolName string, modelName string, providerName string) (* // Check if the pool was found if selectedPool == nil { - slog.Error("pool '%s' not found", poolName) + slog.Error("pool not found") + return nil, fmt.Errorf("pool not found: %s", poolName) } // Find the OpenAI provider in the selected pool with the specified model @@ -123,6 +120,7 @@ func (c *Client) Init(poolName string, modelName string, providerName string) (* // Check if the provider was found if selectedProvider == nil { slog.Error("provider for model '%s' not found in pool '%s'", modelName, poolName) + return nil, fmt.Errorf("provider for model '%s' not found in pool '%s'", modelName, poolName) } // Create clients for each OpenAI provider @@ -147,88 +145,6 @@ func HttpClient() *client.Client { return c } - -func (c *Client) CreateChatRequest(message []byte) *ChatRequest { - - - err := json.Unmarshal(message, &requestBody) - if err != nil { - slog.Error("Error:", err) - return nil - } - - var messages []*ChatMessage - for _, msg := range requestBody.Message { - chatMsg := &ChatMessage{ - Role: msg.Role, - Content: msg.Content, - } - if msg.Role == "user" { - chatMsg.Content += " " + strings.Join(requestBody.MessageHistory, " ") - } - messages = append(messages, chatMsg) - } - - // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value - - chatRequest := &ChatRequest{ - Model: c.Provider.Model, - Messages: messages, - Temperature: 0.8, - TopP: 1, - MaxTokens: 100, - N: 1, - StopWords: []string{}, - Stream: false, - FrequencyPenalty: 0, - PresencePenalty: 0, - LogitBias: nil, - User: nil, - Seed: nil, - Tools: []string{}, - ToolChoice: nil, - ResponseFormat: nil, - } - - // Use reflection to dynamically assign default parameter values - defaultParams := c.Provider.DefaultParams - v := reflect.ValueOf(chatRequest).Elem() - t := v.Type() - for i := 0; i < v.NumField(); i++ { - field := t.Field(i) - fieldName := field.Name - defaultValue, ok := defaultParams[fieldName] - if ok && defaultValue != nil { - fieldValue := v.FieldByName(fieldName) - if fieldValue.IsValid() && fieldValue.CanSet() { - fieldValue.Set(reflect.ValueOf(defaultValue)) - } - } - } - - return chatRequest -} - -// CreateChatResponse creates chat Response. -func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { - if r.Model == "" { - if c.Provider.Model == "" { - r.Model = defaultChatModel - } else { - r.Model = c.Provider.Model - } - } - - resp, err := c.createChat(ctx, r) - if err != nil { - return nil, err - } - if len(resp.Choices) == 0 { - return nil, ErrEmptyResponse - } - return resp, nil -} - func IsAzure(apiType APIType) bool { return apiType == APITypeAzure || apiType == APITypeAzureAD } @@ -264,15 +180,3 @@ func (c *Client) buildAzureURL(suffix string, model string) string { baseURL, model, suffix, c.apiVersion, ) } - -func main() { - - c := &Client{} - - c, err := c.Init("pool1", "gpt-3.5-turbo", "openai") - if err != nil { - // Handle the error - } - - -} \ No newline at end of file From a5caa77f45eb2b2c328f445148248173ea87530d Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 15:55:57 -0700 Subject: [PATCH 15/61] #29: add run method --- pkg/providers/openai/openaiclient.go | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 52a6a646..7a1b3ba8 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -7,6 +7,7 @@ import ( "log/slog" "gopkg.in/yaml.v2" "os" + "context" //"Glide/pkg/providers" @@ -76,6 +77,26 @@ type Client struct { // Option is an option for the OpenAI client. type Option func(*Client) error +func (c *Client) Run() (*ChatResponse, error) { + + c = &Client{} + + c, err := c.Init("pool1", "gpt-3.5-turbo", "openai") + if err != nil { + slog.Error("Error:" + err.Error()) + return nil, err + } + + // Create a new chat request + chatRequest := c.CreateChatRequest([]byte("hello world")) + + // Send the chat request + + resp, err := c.CreateChatResponse(context.Background(), chatRequest) + + return resp, err +} + func (c *Client) Init(poolName string, modelName string, providerName string) (*Client, error) { // Returns a []*Client of OpenAI @@ -135,7 +156,6 @@ func (c *Client) Init(poolName string, modelName string, providerName string) (* } - func HttpClient() *client.Client { c, err := client.NewClient() From db38f443a9b0a829803ebf19d4cdc362e946afbc Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 19:56:17 -0700 Subject: [PATCH 16/61] #29: tests passing --- .gitignore | 1 + pkg/providers/openai/chat.go | 145 +++++++++++++++---------- pkg/providers/openai/openai_test.go | 38 +++++++ pkg/providers/openai/openaiclient.go | 151 +++++++++------------------ pkg/providers/types.go | 11 +- 5 files changed, 188 insertions(+), 158 deletions(-) create mode 100644 pkg/providers/openai/openai_test.go diff --git a/.gitignore b/.gitignore index 951841c7..012b879f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea dist .env +config.yaml diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 022f1125..44c0fd6c 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -1,7 +1,6 @@ package openai import ( - "bufio" "bytes" "context" "encoding/json" @@ -10,6 +9,7 @@ import ( "reflect" "strings" "net/http" + "io" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -127,6 +127,8 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { } } + fmt.Println(chatRequest) + return chatRequest } @@ -171,7 +173,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR } } - resp, err := c.createChat(ctx, r) + resp, err := c.createChatHttp(ctx, r) if err != nil { return nil, err } @@ -183,6 +185,9 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { + + slog.Info("running createChat") + if payload.StreamingFunc != nil { payload.Stream = true } @@ -202,80 +207,110 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes req.Header.SetMethod(consts.MethodPost) req.SetRequestURI(c.buildURL("/chat/completions", c.Provider.Model)) req.SetBody(payloadBytes) + req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) + req.Header.Set("Content-Type", "application/json") - c.setHeaders(req) // sets additional headers + slog.Info("making request") // Send request err = c.httpClient.Do(ctx, req, res) //*client.Client if err != nil { + slog.Error(err.Error()) + fmt.Println(res.Body()) return nil, err } + slog.Info("request returned") + defer res.ConnectionClose() // replaced r.Body.Close() + + + slog.Info(fmt.Sprintf("%d", res.StatusCode())) + if res.StatusCode() != http.StatusOK { msg := fmt.Sprintf("API returned unexpected status code: %d", res.StatusCode()) return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113 } - if payload.StreamingFunc != nil { - return parseStreamingChatResponse(ctx, res, payload) - } + // Parse response var response ChatResponse return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response) } -func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, payload *ChatRequest) (*ChatResponse, error) { - scanner := bufio.NewScanner(bytes.NewReader(r.Body())) - responseChan := make(chan StreamedChatResponsePayload) - go func() { - defer close(responseChan) - for scanner.Scan() { - line := scanner.Text() - if line == "" { - continue - } - if !strings.HasPrefix(line, "data:") { - slog.Warn("unexpected line:" + line) - } - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - return - } - var streamPayload StreamedChatResponsePayload - err := json.NewDecoder(bytes.NewReader([]byte(data))).Decode(&streamPayload) - if err != nil { - slog.Error("failed to decode stream payload: %v", err) - } - responseChan <- streamPayload - } - if err := scanner.Err(); err != nil { - slog.Error("issue scanning response:", err) - } - }() +func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { + slog.Info("running createChatHttp") - // Parse response - response := ChatResponse{ - Choices: []*ChatChoice{ - {}, - }, + if payload.StreamingFunc != nil { + payload.Stream = true + } + // Build request payload + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err } - for streamResponse := range responseChan { - if len(streamResponse.Choices) == 0 { - continue - } - chunk := []byte(streamResponse.Choices[0].Delta.Content) - response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content - response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason - - if payload.StreamingFunc != nil { - err := payload.StreamingFunc(ctx, chunk) - if err != nil { - return nil, fmt.Errorf("streaming func returned an error: %w", err) - } - } + // Build request + if c.baseURL == "" { + c.baseURL = defaultBaseURL + } + + reqBody := bytes.NewBuffer(payloadBytes) + req, err := http.NewRequest("POST", c.buildURL("/chat/completions", c.Provider.Model), reqBody) + if err != nil { + slog.Error(err.Error()) + return nil, err } - return &response, nil -} \ No newline at end of file + + req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + slog.Error(err.Error()) + return nil, err + } + defer resp.Body.Close() + + slog.Info(fmt.Sprintf("%d", resp.StatusCode)) + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + slog.Error(err.Error()) +} + bodyString := string(bodyBytes) + slog.Info(bodyString) + + // Parse response + var response ChatResponse + return &response, json.NewDecoder(resp.Body).Decode(&response) +} + +func IsAzure(apiType APIType) bool { + return apiType == APITypeAzure || apiType == APITypeAzureAD +} + + +func (c *Client) buildURL(suffix string, model string) string { + if IsAzure(c.apiType) { + return c.buildAzureURL(suffix, model) + } + + slog.Info("request url: " + fmt.Sprintf("%s%s", c.baseURL, suffix)) + + // open ai implement: + return fmt.Sprintf("%s%s", c.baseURL, suffix) +} + +func (c *Client) buildAzureURL(suffix string, model string) string { + baseURL := c.baseURL + baseURL = strings.TrimRight(baseURL, "/") + + // azure example url: + // /openai/deployments/{model}/chat/completions?api-version={api_version} + return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", + baseURL, model, suffix, c.apiVersion, + ) +} diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go new file mode 100644 index 00000000..57be6eef --- /dev/null +++ b/pkg/providers/openai/openai_test.go @@ -0,0 +1,38 @@ +package openai + +import ( + "testing" + "encoding/json" + "fmt" +) + + + +func TestOpenAIClient(t *testing.T) { + // Initialize the OpenAI client + + poolName := "default" + modelName := "gpt-3.5-turbo" + + payload := map[string]interface{}{ + "message": []map[string]string{ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": "tell me a joke", + }, + }, + "messageHistory": []string{"Hello there", "How are you?", "I'm good, how about you?"}, + } + + payloadBytes, _ := json.Marshal(payload) + + c := &Client{} + + resp, err := c.Run(poolName, modelName, payloadBytes) + + fmt.Println(resp, err) +} \ No newline at end of file diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 7a1b3ba8..758a7c36 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -1,38 +1,35 @@ package openai import ( + "context" "errors" "fmt" - "strings" - "log/slog" "gopkg.in/yaml.v2" + "log/slog" "os" - "context" - //"Glide/pkg/providers" + "glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" - "github.com/cloudwego/hertz/pkg/protocol" - ) const ( defaultBaseURL = "https://api.openai.com/v1" - defaultOrganization = "" - defaultFunctionCallBehavior = "auto" + defaultOrganization = "" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. var ( ErrEmptyResponse = errors.New("empty response") - requestBody struct { - Message []struct { + requestBody struct { + Message []struct { Role string `json:"role"` Content string `json:"content"` } `json:"message"` MessageHistory []string `json:"messageHistory"` } ) + type APIType string const ( @@ -40,83 +37,73 @@ const ( APITypeAzure APIType = "AZURE" APITypeAzureAD APIType = "AZURE_AD" ) - -type GatewayConfig struct { - Pools []Pool `yaml:"pools"` -} - -type Pool struct { - Name string `yaml:"name"` - Balancing string `yaml:"balancing"` - Providers []Provider `yaml:"providers"` -} - -type Provider struct { - Name string `yaml:"name"` - Provider string `yaml:"provider"` - Model string `yaml:"model"` - ApiKey string `yaml:"api_key"` - TimeoutMs int `yaml:"timeout_ms,omitempty"` - DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` -} - - - // Client is a client for the OpenAI API. type Client struct { - Provider Provider + Provider providers.Provider baseURL string organization string apiType APIType httpClient *client.Client // required when APIType is APITypeAzure or APITypeAzureAD - apiVersion string + apiVersion string } -// Option is an option for the OpenAI client. -type Option func(*Client) error -func (c *Client) Run() (*ChatResponse, error) { - - c = &Client{} +func (c *Client) Run(poolName string, modelName string, payload []byte) (*ChatResponse, error) { - c, err := c.Init("pool1", "gpt-3.5-turbo", "openai") + c, err := c.NewClient(poolName, modelName) if err != nil { slog.Error("Error:" + err.Error()) return nil, err } // Create a new chat request - chatRequest := c.CreateChatRequest([]byte("hello world")) + + slog.Info("creating chat request") + + chatRequest := c.CreateChatRequest(payload) + + slog.Info("chat request created") // Send the chat request + slog.Info("sending chat request") + resp, err := c.CreateChatResponse(context.Background(), chatRequest) return resp, err } - -func (c *Client) Init(poolName string, modelName string, providerName string) (*Client, error) { +func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // Returns a []*Client of OpenAI + // modelName is determined by the model pool + // poolName is determined by the route the request came from + + var providerName = "openai" // Read the YAML file - data, err := os.ReadFile("config.yaml") + data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") + if err != nil { slog.Error("Failed to read file: %v", err) + return nil, err } + slog.Info("config found") + // Unmarshal the YAML data into your struct - var config GatewayConfig + var config providers.GatewayConfig err = yaml.Unmarshal(data, &config) if err != nil { slog.Error("Failed to unmarshal YAML: %v", err) } + fmt.Println(config) + // Find the pool with the specified name - var selectedPool *Pool - for _, pool := range config.Pools { + var selectedPool *providers.Pool + for _, pool := range config.Gateway.Pools { if pool.Name == poolName { selectedPool = &pool break @@ -130,30 +117,30 @@ func (c *Client) Init(poolName string, modelName string, providerName string) (* } // Find the OpenAI provider in the selected pool with the specified model - var selectedProvider *Provider - for _, provider := range selectedPool.Providers { - if provider.Name == providerName && provider.Model == modelName { - selectedProvider = &provider - break - } - } - - // Check if the provider was found - if selectedProvider == nil { - slog.Error("provider for model '%s' not found in pool '%s'", modelName, poolName) + var selectedProvider *providers.Provider + for _, provider := range selectedPool.Providers { + if provider.Provider == providerName && provider.Model == modelName { + selectedProvider = &provider + break + } + } + + // Check if the provider was found + if selectedProvider == nil { + slog.Error("provider for model '%s' not found in pool '%s'", modelName, poolName) return nil, fmt.Errorf("provider for model '%s' not found in pool '%s'", modelName, poolName) - } + } // Create clients for each OpenAI provider client := &Client{ - Provider: *selectedProvider, - organization: defaultOrganization, - apiType: APITypeOpenAI, - httpClient: HttpClient(), - } + Provider: *selectedProvider, + organization: defaultOrganization, + apiType: APITypeOpenAI, + httpClient: HttpClient(), + } return client, nil - + } func HttpClient() *client.Client { @@ -165,38 +152,4 @@ func HttpClient() *client.Client { return c } -func IsAzure(apiType APIType) bool { - return apiType == APITypeAzure || apiType == APITypeAzureAD -} - -func (c *Client) setHeaders(req *protocol.Request) { - req.Header.Set("Content-Type", "application/json") - if c.apiType == APITypeOpenAI || c.apiType == APITypeAzureAD { - req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) - } else { - req.Header.Set("api-key", c.Provider.ApiKey) - } - if c.organization != "" { - req.Header.Set("OpenAI-Organization", c.organization) - } -} -func (c *Client) buildURL(suffix string, model string) string { - if IsAzure(c.apiType) { - return c.buildAzureURL(suffix, model) - } - - // open ai implement: - return fmt.Sprintf("%s%s", c.baseURL, suffix) -} - -func (c *Client) buildAzureURL(suffix string, model string) string { - baseURL := c.baseURL - baseURL = strings.TrimRight(baseURL, "/") - - // azure example url: - // /openai/deployments/{model}/chat/completions?api-version={api_version} - return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", - baseURL, model, suffix, c.apiVersion, - ) -} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 0f3a1262..7e4dd8e4 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,19 +1,22 @@ -package provider +package providers + type GatewayConfig struct { + Gateway PoolsConfig `yaml:"gateway"` +} +type PoolsConfig struct { Pools []Pool `yaml:"pools"` } type Pool struct { - Name string `yaml:"name"` + Name string `yaml:"pool"` Balancing string `yaml:"balancing"` Providers []Provider `yaml:"providers"` } type Provider struct { - Name string `yaml:"name"` Provider string `yaml:"provider"` Model string `yaml:"model"` - APIKey string `yaml:"api_key"` + ApiKey string `yaml:"api_key"` TimeoutMs int `yaml:"timeout_ms,omitempty"` DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` } From 463401b3a59890905fab3e120ad97d444c596eaf Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:03:32 -0700 Subject: [PATCH 17/61] #29: tests passing --- pkg/providers/openai/chat.go | 40 +++++++++++++---------------- pkg/providers/openai/openai_test.go | 6 +++-- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 44c0fd6c..bf0cac93 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -5,11 +5,11 @@ import ( "context" "encoding/json" "fmt" + "io" "log/slog" + "net/http" "reflect" "strings" - "net/http" - "io" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -52,7 +52,6 @@ type ChatMessage struct { // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, // with a maximum length of 64 characters. Name string `json:"name,omitempty"` - } // ChatChoice is a choice in a chat response. @@ -71,7 +70,6 @@ type ChatUsage struct { func (c *Client) CreateChatRequest(message []byte) *ChatRequest { - err := json.Unmarshal(message, &requestBody) if err != nil { slog.Error("Error:", err) @@ -91,7 +89,7 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { } // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value - + chatRequest := &ChatRequest{ Model: c.Provider.Model, Messages: messages, @@ -132,7 +130,6 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { return chatRequest } - // ChatResponse is a response to a chat request. type ChatResponse struct { ID string `json:"id,omitempty"` @@ -156,8 +153,8 @@ type StreamedChatResponsePayload struct { Choices []struct { Index float64 `json:"index,omitempty"` Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` } `json:"delta,omitempty"` FinishReason string `json:"finish_reason,omitempty"` } `json:"choices,omitempty"` @@ -172,7 +169,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR r.Model = c.Provider.Model } } - + resp, err := c.createChatHttp(ctx, r) if err != nil { return nil, err @@ -183,11 +180,10 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR return resp, nil } +func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { -func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { - slog.Info("running createChat") - + if payload.StreamingFunc != nil { payload.Stream = true } @@ -224,8 +220,6 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes defer res.ConnectionClose() // replaced r.Body.Close() - - slog.Info(fmt.Sprintf("%d", res.StatusCode())) if res.StatusCode() != http.StatusOK { @@ -233,7 +227,7 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113 } - + // Parse response var response ChatResponse return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response) @@ -276,12 +270,15 @@ func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*Cha slog.Info(fmt.Sprintf("%d", resp.StatusCode)) - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - slog.Error(err.Error()) -} - bodyString := string(bodyBytes) - slog.Info(bodyString) + if resp.StatusCode != http.StatusOK { + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + slog.Error(err.Error()) + } + bodyString := string(bodyBytes) + slog.Warn(bodyString) + } // Parse response var response ChatResponse @@ -292,7 +289,6 @@ func IsAzure(apiType APIType) bool { return apiType == APITypeAzure || apiType == APITypeAzureAD } - func (c *Client) buildURL(suffix string, model string) string { if IsAzure(c.apiType) { return c.buildAzureURL(suffix, model) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 57be6eef..9badd620 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -32,7 +32,9 @@ func TestOpenAIClient(t *testing.T) { c := &Client{} - resp, err := c.Run(poolName, modelName, payloadBytes) + resp, _ := c.Run(poolName, modelName, payloadBytes) - fmt.Println(resp, err) + respJSON, _ := json.Marshal(resp) + + fmt.Println(string(respJSON)) } \ No newline at end of file From e16c13c840d0fd46eb04f3c235296959a590a0a1 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:13:56 -0700 Subject: [PATCH 18/61] #29: add todo --- pkg/providers/openai/openaiclient.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 758a7c36..80d6be81 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -1,3 +1,7 @@ +//TODO: Explore resource pooling +// TODO: Optimize Type use +// TODO: Explore Hertz TLS & resource pooling + package openai import ( From c9d94026d8b872c55ad7d1d553b1177183c4f083 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:26:37 -0700 Subject: [PATCH 19/61] #29: go mod tody --- go.mod | 7 ------- go.sum | 13 ++----------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index a300639b..3fdf7da3 100644 --- a/go.mod +++ b/go.mod @@ -17,9 +17,6 @@ require ( github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/cloudwego/netpoll v0.5.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/henrylee2cn/ameda v1.4.10 // indirect @@ -27,7 +24,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/kr/text v0.2.0 // indirect - github.com/leodido/go-urn v1.2.4 // indirect github.com/nyaruka/phonenumbers v1.0.55 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/testify v1.8.2 // indirect @@ -36,9 +32,6 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index a33a9e0d..725f36f7 100644 --- a/go.sum +++ b/go.sum @@ -44,8 +44,6 @@ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -87,26 +85,19 @@ golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5P golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 2f3695a84bb139a8efd58b6dbc23a76ebe1985ac Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:32:19 -0700 Subject: [PATCH 20/61] #29: gofmt --- .gitignore | 1 + pkg/providers/openai/openai_test.go | 34 +++++++++++++--------------- pkg/providers/openai/openaiclient.go | 9 ++++---- pkg/providers/types.go | 20 ++++++++-------- 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 6fd6161c..d561a8b4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ bin glide tmp coverage.txt +precommit.txt diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 9badd620..f32ebc80 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -1,13 +1,11 @@ package openai import ( - "testing" "encoding/json" "fmt" + "testing" ) - - func TestOpenAIClient(t *testing.T) { // Initialize the OpenAI client @@ -15,20 +13,20 @@ func TestOpenAIClient(t *testing.T) { modelName := "gpt-3.5-turbo" payload := map[string]interface{}{ - "message": []map[string]string{ - { - "role": "system", - "content": "You are a helpful assistant.", - }, - { - "role": "user", - "content": "tell me a joke", - }, - }, - "messageHistory": []string{"Hello there", "How are you?", "I'm good, how about you?"}, - } - - payloadBytes, _ := json.Marshal(payload) + "message": []map[string]string{ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": "tell me a joke", + }, + }, + "messageHistory": []string{"Hello there", "How are you?", "I'm good, how about you?"}, + } + + payloadBytes, _ := json.Marshal(payload) c := &Client{} @@ -37,4 +35,4 @@ func TestOpenAIClient(t *testing.T) { respJSON, _ := json.Marshal(resp) fmt.Println(string(respJSON)) -} \ No newline at end of file +} diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 80d6be81..4688325c 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -18,8 +18,8 @@ import ( ) const ( - defaultBaseURL = "https://api.openai.com/v1" - defaultOrganization = "" + defaultBaseURL = "https://api.openai.com/v1" + defaultOrganization = "" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. @@ -41,6 +41,7 @@ const ( APITypeAzure APIType = "AZURE" APITypeAzureAD APIType = "AZURE_AD" ) + // Client is a client for the OpenAI API. type Client struct { Provider providers.Provider @@ -53,7 +54,6 @@ type Client struct { apiVersion string } - func (c *Client) Run(poolName string, modelName string, payload []byte) (*ChatResponse, error) { c, err := c.NewClient(poolName, modelName) @@ -88,7 +88,7 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // Read the YAML file data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") - + if err != nil { slog.Error("Failed to read file: %v", err) return nil, err @@ -156,4 +156,3 @@ func HttpClient() *client.Client { return c } - diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 7e4dd8e4..022ced99 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -8,23 +8,23 @@ type PoolsConfig struct { } type Pool struct { - Name string `yaml:"pool"` - Balancing string `yaml:"balancing"` - Providers []Provider `yaml:"providers"` + Name string `yaml:"pool"` + Balancing string `yaml:"balancing"` + Providers []Provider `yaml:"providers"` } type Provider struct { - Provider string `yaml:"provider"` - Model string `yaml:"model"` - ApiKey string `yaml:"api_key"` - TimeoutMs int `yaml:"timeout_ms,omitempty"` - DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` + Provider string `yaml:"provider"` + Model string `yaml:"model"` + ApiKey string `yaml:"api_key"` + TimeoutMs int `yaml:"timeout_ms,omitempty"` + DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` } type RequestBody struct { - Message []struct { + Message []struct { Role string `json:"role"` Content string `json:"content"` } `json:"message"` MessageHistory []string `json:"messageHistory"` -} \ No newline at end of file +} From 519128e8bd08ba408eeb1c107f8704f90f91d397 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:40:14 -0700 Subject: [PATCH 21/61] #29: gofumpt --- pkg/providers/openai/chat.go | 2 -- pkg/providers/openai/openaiclient.go | 12 ++++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index bf0cac93..58b9088c 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -69,7 +69,6 @@ type ChatUsage struct { } func (c *Client) CreateChatRequest(message []byte) *ChatRequest { - err := json.Unmarshal(message, &requestBody) if err != nil { slog.Error("Error:", err) @@ -181,7 +180,6 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR } func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { - slog.Info("running createChat") if payload.StreamingFunc != nil { diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 4688325c..7ef403eb 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -1,4 +1,4 @@ -//TODO: Explore resource pooling +// TODO: Explore resource pooling // TODO: Optimize Type use // TODO: Explore Hertz TLS & resource pooling @@ -8,10 +8,11 @@ import ( "context" "errors" "fmt" - "gopkg.in/yaml.v2" "log/slog" "os" + "gopkg.in/yaml.v2" + "glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" @@ -55,7 +56,6 @@ type Client struct { } func (c *Client) Run(poolName string, modelName string, payload []byte) (*ChatResponse, error) { - c, err := c.NewClient(poolName, modelName) if err != nil { slog.Error("Error:" + err.Error()) @@ -84,11 +84,10 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // modelName is determined by the model pool // poolName is determined by the route the request came from - var providerName = "openai" + providerName := "openai" // Read the YAML file data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") - if err != nil { slog.Error("Failed to read file: %v", err) return nil, err @@ -144,15 +143,12 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { } return client, nil - } func HttpClient() *client.Client { - c, err := client.NewClient() if err != nil { slog.Error(err.Error()) } return c - } From 1d48263171f4c78dbed06d1c16112996c1b520d9 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:54:16 -0700 Subject: [PATCH 22/61] #29: lint --- pkg/providers/openai/chat.go | 14 ++++++-------- pkg/providers/openai/openaiclient.go | 4 ++-- pkg/providers/types.go | 2 +- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 58b9088c..5935bc00 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -10,9 +10,6 @@ import ( "net/http" "reflect" "strings" - - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/protocol/consts" ) const ( @@ -169,7 +166,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR } } - resp, err := c.createChatHttp(ctx, r) + resp, err := c.createChatHttp(r) if err != nil { return nil, err } @@ -178,7 +175,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR } return resp, nil } - +/* will remove later func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { slog.Info("running createChat") @@ -201,7 +198,7 @@ func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*Ch req.Header.SetMethod(consts.MethodPost) req.SetRequestURI(c.buildURL("/chat/completions", c.Provider.Model)) req.SetBody(payloadBytes) - req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) + req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) req.Header.Set("Content-Type", "application/json") slog.Info("making request") @@ -230,8 +227,9 @@ func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*Ch var response ChatResponse return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response) } +*/ -func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { +func (c *Client) createChatHttp(payload *ChatRequest) (*ChatResponse, error) { slog.Info("running createChatHttp") if payload.StreamingFunc != nil { @@ -255,7 +253,7 @@ func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*Cha return nil, err } - req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) + req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) req.Header.Set("Content-Type", "application/json") httpClient := &http.Client{} diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 7ef403eb..6ba2a63b 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -139,13 +139,13 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { Provider: *selectedProvider, organization: defaultOrganization, apiType: APITypeOpenAI, - httpClient: HttpClient(), + httpClient: HTTPClient(), } return client, nil } -func HttpClient() *client.Client { +func HTTPClient() *client.Client { c, err := client.NewClient() if err != nil { slog.Error(err.Error()) diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 022ced99..e9ce7901 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -16,7 +16,7 @@ type Pool struct { type Provider struct { Provider string `yaml:"provider"` Model string `yaml:"model"` - ApiKey string `yaml:"api_key"` + APIKey string `yaml:"api_key"` TimeoutMs int `yaml:"timeout_ms,omitempty"` DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` } From f54f3008eaa828034f2126117417d7763c07184e Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:56:14 -0700 Subject: [PATCH 23/61] #29: lint --- pkg/providers/openai/chat.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 5935bc00..b28e74f1 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -175,6 +175,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR } return resp, nil } + /* will remove later func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { slog.Info("running createChat") From 843997d2e352002c3782364765ea3c546bb2c9d6 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 21:04:33 -0700 Subject: [PATCH 24/61] #29: lint --- pkg/providers/openai/chat.go | 6 ++++-- pkg/providers/openai/openai_test.go | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index b28e74f1..f5be4c0b 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -166,7 +166,9 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR } } - resp, err := c.createChatHttp(r) + _ = ctx // keep this for future use + + resp, err := c.createChatHTTP(r) if err != nil { return nil, err } @@ -230,7 +232,7 @@ func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*Ch } */ -func (c *Client) createChatHttp(payload *ChatRequest) (*ChatResponse, error) { +func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { slog.Info("running createChatHttp") if payload.StreamingFunc != nil { diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index f32ebc80..f070b28f 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -9,6 +9,8 @@ import ( func TestOpenAIClient(t *testing.T) { // Initialize the OpenAI client + _ = t + poolName := "default" modelName := "gpt-3.5-turbo" From 4f9eda084f84fb4d930a6ebc2f56042585fcefd2 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 21:18:34 -0700 Subject: [PATCH 25/61] #29: fix Implicit memory aliasing --- pkg/providers/openai/openai_test.go | 2 +- pkg/providers/openai/openaiclient.go | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index f070b28f..7e738b7c 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -9,7 +9,7 @@ import ( func TestOpenAIClient(t *testing.T) { // Initialize the OpenAI client - _ = t + var _ = t poolName := "default" modelName := "gpt-3.5-turbo" diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 6ba2a63b..60275fde 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -106,12 +106,13 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // Find the pool with the specified name var selectedPool *providers.Pool - for _, pool := range config.Gateway.Pools { + for i := range config.Gateway.Pools { + pool := &config.Gateway.Pools[i] if pool.Name == poolName { - selectedPool = &pool + selectedPool = pool break } - } +} // Check if the pool was found if selectedPool == nil { @@ -121,9 +122,10 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // Find the OpenAI provider in the selected pool with the specified model var selectedProvider *providers.Provider - for _, provider := range selectedPool.Providers { + for i := range selectedPool.Providers { + provider := &selectedPool.Providers[i] if provider.Provider == providerName && provider.Model == modelName { - selectedProvider = &provider + selectedProvider = provider break } } From 388f01c4638e73e8b13f82b3c3dfbdc1a3f81e38 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 21:19:26 -0700 Subject: [PATCH 26/61] #29: lint --- pkg/providers/openai/openai_test.go | 2 +- pkg/providers/openai/openaiclient.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 7e738b7c..f070b28f 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -9,7 +9,7 @@ import ( func TestOpenAIClient(t *testing.T) { // Initialize the OpenAI client - var _ = t + _ = t poolName := "default" modelName := "gpt-3.5-turbo" diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 60275fde..8b85b8d4 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -112,7 +112,7 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { selectedPool = pool break } -} + } // Check if the pool was found if selectedPool == nil { From f38473e6b3753ef2a84bb68c02131296f1404c7b Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 07:36:14 -0700 Subject: [PATCH 27/61] #29: update request from defaultParams --- pkg/providers/openai/chat.go | 52 ++++++++++--------- pkg/providers/openai/openai_test.go | 4 +- pkg/providers/openai/openaiclient.go | 76 +++++++++++++++++----------- 3 files changed, 76 insertions(+), 56 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index f5be4c0b..40ec17f4 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -8,8 +8,8 @@ import ( "io" "log/slog" "net/http" - "reflect" "strings" + "reflect" ) const ( @@ -20,10 +20,10 @@ const ( type ChatRequest struct { Model string `json:"model" validate:"required,lowercase"` Messages []*ChatMessage `json:"messages" validate:"required"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` - MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` - N int `json:"n,omitempty" validate:"omitempty,gte=1"` + Temperature float64 `json:"temperature" validate:"omitempty,gte=0,lte=1"` + TopP float64 `json:"top_p" validate:"omitempty,gte=0,lte=1"` + MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` + N int `json:"n" validate:"omitempty,gte=1"` StopWords []string `json:"stop,omitempty"` Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` FrequencyPenalty int `json:"frequency_penalty,omitempty"` @@ -72,6 +72,8 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { return nil } + slog.Info("creating chatRequest from payload") + var messages []*ChatMessage for _, msg := range requestBody.Message { chatMsg := &ChatMessage{ @@ -87,7 +89,7 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value chatRequest := &ChatRequest{ - Model: c.Provider.Model, + Model: c.setModel(), Messages: messages, Temperature: 0.8, TopP: 1, @@ -107,21 +109,20 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { // Use reflection to dynamically assign default parameter values defaultParams := c.Provider.DefaultParams - v := reflect.ValueOf(chatRequest).Elem() - t := v.Type() - for i := 0; i < v.NumField(); i++ { - field := t.Field(i) - fieldName := field.Name - defaultValue, ok := defaultParams[fieldName] - if ok && defaultValue != nil { - fieldValue := v.FieldByName(fieldName) - if fieldValue.IsValid() && fieldValue.CanSet() { - fieldValue.Set(reflect.ValueOf(defaultValue)) - } + + chatRequestValue := reflect.ValueOf(chatRequest).Elem() + chatRequestType := chatRequestValue.Type() + + for i := 0; i < chatRequestValue.NumField(); i++ { + jsonTag := chatRequestType.Field(i).Tag.Get("json") + fmt.Println(jsonTag) + if value, ok := defaultParams[jsonTag]; ok { + fieldValue := chatRequestValue.Field(i) + fieldValue.Set(reflect.ValueOf(value)) } } - fmt.Println(chatRequest) + fmt.Println(chatRequest, defaultParams) return chatRequest } @@ -158,13 +159,6 @@ type StreamedChatResponsePayload struct { // CreateChatResponse creates chat Response. func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { - if r.Model == "" { - if c.Provider.Model == "" { - r.Model = defaultChatModel - } else { - r.Model = c.Provider.Model - } - } _ = ctx // keep this for future use @@ -309,3 +303,11 @@ func (c *Client) buildAzureURL(suffix string, model string) string { baseURL, model, suffix, c.apiVersion, ) } + +func (c *Client) setModel() string { + if c.Provider.Model == "" { + return defaultChatModel + } else { + return c.Provider.Model + } +} diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index f070b28f..93709a9b 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -30,9 +30,9 @@ func TestOpenAIClient(t *testing.T) { payloadBytes, _ := json.Marshal(payload) - c := &Client{} + c, _ := OpenAiClient(poolName, modelName, payloadBytes) - resp, _ := c.Run(poolName, modelName, payloadBytes) + resp, _ := c.Chat() respJSON, _ := json.Marshal(resp) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 8b85b8d4..bced1959 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -1,7 +1,7 @@ // TODO: Explore resource pooling // TODO: Optimize Type use // TODO: Explore Hertz TLS & resource pooling - +// OpenAI package provide a set of functions to interact with the OpenAI API. package openai import ( @@ -46,7 +46,9 @@ const ( // Client is a client for the OpenAI API. type Client struct { Provider providers.Provider + PoolName string baseURL string + payload []byte organization string apiType APIType httpClient *client.Client @@ -55,34 +57,16 @@ type Client struct { apiVersion string } -func (c *Client) Run(poolName string, modelName string, payload []byte) (*ChatResponse, error) { - c, err := c.NewClient(poolName, modelName) - if err != nil { - slog.Error("Error:" + err.Error()) - return nil, err - } - - // Create a new chat request - - slog.Info("creating chat request") - - chatRequest := c.CreateChatRequest(payload) - - slog.Info("chat request created") - - // Send the chat request - - slog.Info("sending chat request") - - resp, err := c.CreateChatResponse(context.Background(), chatRequest) - - return resp, err -} - -func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { - // Returns a []*Client of OpenAI - // modelName is determined by the model pool - // poolName is determined by the route the request came from +// OpenAiClient creates a new client for the OpenAI API. +// +// Parameters: +// - poolName: The name of the pool to connect to. +// - modelName: The name of the model to use. +// +// Returns: +// - *Client: A pointer to the created client. +// - error: An error if the client creation failed. +func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) { providerName := "openai" @@ -139,6 +123,9 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // Create clients for each OpenAI provider client := &Client{ Provider: *selectedProvider, + PoolName: poolName, + baseURL: defaultBaseURL, + payload: payload, organization: defaultOrganization, apiType: APITypeOpenAI, httpClient: HTTPClient(), @@ -147,6 +134,37 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { return client, nil } +// Chat sends a chat request to the specified OpenAI model. +// +// Parameters: +// - payload: The user payload for the chat request. +// Returns: +// - *ChatResponse: a pointer to a ChatResponse +// - error: An error if the request failed. +func (c *Client) Chat() (*ChatResponse, error) { + + // Create a new chat request + + slog.Info("creating chat request") + + chatRequest := c.CreateChatRequest(c.payload) + + slog.Info("chat request created") + + // Send the chat request + + slog.Info("sending chat request") + + resp, err := c.CreateChatResponse(context.Background(), chatRequest) + + return resp, err +} + +// HTTPClient returns a new Hertz HTTP client. +// +// It creates a new client using the client.NewClient() function and returns the client. +// If an error occurs during the creation of the client, it logs the error using slog.Error(). +// The function returns the created client or nil if an error occurred. func HTTPClient() *client.Client { c, err := client.NewClient() if err != nil { From cc13f380c702e1ebd623f541f3396aca1c363a3c Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 07:40:58 -0700 Subject: [PATCH 28/61] #29: lint --- pkg/providers/openai/chat.go | 30 ++++++++++++---------------- pkg/providers/openai/openaiclient.go | 6 ++---- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 40ec17f4..545f442e 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -8,8 +8,8 @@ import ( "io" "log/slog" "net/http" - "strings" "reflect" + "strings" ) const ( @@ -24,16 +24,16 @@ type ChatRequest struct { TopP float64 `json:"top_p" validate:"omitempty,gte=0,lte=1"` MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` N int `json:"n" validate:"omitempty,gte=1"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty" validate:"omitempty"` - User interface{} `json:"user,omitempty"` - Seed interface{} `json:"seed,omitempty" validate:"omitempty,gte=0"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` + StopWords []string `json:"stop"` + Stream bool `json:"stream" validate:"omitempty, boolean"` + FrequencyPenalty int `json:"frequency_penalty"` + PresencePenalty int `json:"presence_penalty"` + LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"` + User interface{} `json:"user"` + Seed interface{} `json:"seed" validate:"omitempty,gte=0"` + Tools []string `json:"tools"` + ToolChoice interface{} `json:"tool_choice"` + ResponseFormat interface{} `json:"response_format"` // StreamingFunc is a function to be called for each chunk of a streaming response. // Return an error to stop streaming early. @@ -115,15 +115,12 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { for i := 0; i < chatRequestValue.NumField(); i++ { jsonTag := chatRequestType.Field(i).Tag.Get("json") - fmt.Println(jsonTag) if value, ok := defaultParams[jsonTag]; ok { fieldValue := chatRequestValue.Field(i) fieldValue.Set(reflect.ValueOf(value)) } } - fmt.Println(chatRequest, defaultParams) - return chatRequest } @@ -159,7 +156,6 @@ type StreamedChatResponsePayload struct { // CreateChatResponse creates chat Response. func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { - _ = ctx // keep this for future use resp, err := c.createChatHTTP(r) @@ -307,7 +303,7 @@ func (c *Client) buildAzureURL(suffix string, model string) string { func (c *Client) setModel() string { if c.Provider.Model == "" { return defaultChatModel - } else { - return c.Provider.Model } + + return c.Provider.Model } diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index bced1959..13126c43 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -46,7 +46,7 @@ const ( // Client is a client for the OpenAI API. type Client struct { Provider providers.Provider - PoolName string + PoolName string baseURL string payload []byte organization string @@ -67,7 +67,6 @@ type Client struct { // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) { - providerName := "openai" // Read the YAML file @@ -138,11 +137,10 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e // // Parameters: // - payload: The user payload for the chat request. -// Returns: +// Returns: // - *ChatResponse: a pointer to a ChatResponse // - error: An error if the request failed. func (c *Client) Chat() (*ChatResponse, error) { - // Create a new chat request slog.Info("creating chat request") From a9a9f0b8ddb0973e8b10af8c5576041be7db68b4 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 16:25:27 -0700 Subject: [PATCH 29/61] #29: lint --- .gitignore | 1 + pkg/providers/openai/chat.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index d561a8b4..f6da76a8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ glide tmp coverage.txt precommit.txt +openai_test.go \ No newline at end of file diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 545f442e..316e19a2 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -304,6 +304,6 @@ func (c *Client) setModel() string { if c.Provider.Model == "" { return defaultChatModel } - + return c.Provider.Model } From fe9ad4f55782684cba486b452ef006db5d4bd61b Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 17:09:21 -0700 Subject: [PATCH 30/61] #29: emove unused variable in OpenAiClient constructor --- pkg/providers/openai/openaiclient.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 13126c43..33cc9717 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -126,7 +126,6 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e baseURL: defaultBaseURL, payload: payload, organization: defaultOrganization, - apiType: APITypeOpenAI, httpClient: HTTPClient(), } From bb852591eac466e4c794ed4a7b6a16c596de3369 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 17:51:23 -0700 Subject: [PATCH 31/61] #29: Update client --- go.mod | 9 +++- go.sum | 18 +++++++ pkg/providers/openai/chat.go | 73 ++++++++-------------------- pkg/providers/openai/openaiclient.go | 49 ++++++++----------- 4 files changed, 68 insertions(+), 81 deletions(-) diff --git a/go.mod b/go.mod index 3fdf7da3..834c468f 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21.5 require ( github.com/cloudwego/hertz v0.7.3 + github.com/go-playground/validator/v10 v10.16.0 github.com/spf13/cobra v1.8.0 go.uber.org/goleak v1.3.0 go.uber.org/multierr v1.11.0 @@ -17,6 +18,9 @@ require ( github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/cloudwego/netpoll v0.5.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/henrylee2cn/ameda v1.4.10 // indirect @@ -24,14 +28,17 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/kr/text v0.2.0 // indirect + github.com/leodido/go-urn v1.2.4 // indirect github.com/nyaruka/phonenumbers v1.0.55 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/testify v1.8.2 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/crypto v0.7.0 // indirect + golang.org/x/net v0.8.0 // indirect golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.8.0 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index 725f36f7..cbd30fe6 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= +github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -44,6 +54,8 @@ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -85,7 +97,11 @@ golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5P golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -93,6 +109,8 @@ golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 316e19a2..dcd5cec2 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -14,26 +14,27 @@ import ( const ( defaultChatModel = "gpt-3.5-turbo" + defaultEndpoint = "/chat/completions" ) // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { Model string `json:"model" validate:"required,lowercase"` Messages []*ChatMessage `json:"messages" validate:"required"` - Temperature float64 `json:"temperature" validate:"omitempty,gte=0,lte=1"` - TopP float64 `json:"top_p" validate:"omitempty,gte=0,lte=1"` - MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` - N int `json:"n" validate:"omitempty,gte=1"` - StopWords []string `json:"stop"` - Stream bool `json:"stream" validate:"omitempty, boolean"` - FrequencyPenalty int `json:"frequency_penalty"` - PresencePenalty int `json:"presence_penalty"` - LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"` - User interface{} `json:"user"` - Seed interface{} `json:"seed" validate:"omitempty,gte=0"` - Tools []string `json:"tools"` - ToolChoice interface{} `json:"tool_choice"` - ResponseFormat interface{} `json:"response_format"` + Temperature float64 `json:"temperature,omitempty" validate:"omitempty,gte=0,lte=1"` + TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` + MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` + N int `json:"n,omitempty" validate:"omitempty,gte=1"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty" validate:"omitempty"` + User interface{} `json:"user,omitempty"` + Seed interface{} `json:"seed,omitempty" validate:"omitempty,gte=0"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` // StreamingFunc is a function to be called for each chunk of a streaming response. // Return an error to stop streaming early. @@ -114,7 +115,8 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { chatRequestType := chatRequestValue.Type() for i := 0; i < chatRequestValue.NumField(); i++ { - jsonTag := chatRequestType.Field(i).Tag.Get("json") + jsonTags := strings.Split(chatRequestType.Field(i).Tag.Get("json"), ",") + jsonTag := jsonTags[0] if value, ok := defaultParams[jsonTag]; ok { fieldValue := chatRequestValue.Field(i) fieldValue.Set(reflect.ValueOf(value)) @@ -138,22 +140,6 @@ type ChatResponse struct { } `json:"usage,omitempty"` } -// StreamedChatResponsePayload is a chunk from the stream. -type StreamedChatResponsePayload struct { - ID string `json:"id,omitempty"` - Created float64 `json:"created,omitempty"` - Model string `json:"model,omitempty"` - Object string `json:"object,omitempty"` - Choices []struct { - Index float64 `json:"index,omitempty"` - Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - } `json:"delta,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - } `json:"choices,omitempty"` -} - // CreateChatResponse creates chat Response. func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { _ = ctx // keep this for future use @@ -240,12 +226,14 @@ func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { } reqBody := bytes.NewBuffer(payloadBytes) - req, err := http.NewRequest("POST", c.buildURL("/chat/completions", c.Provider.Model), reqBody) + req, err := http.NewRequest("POST", c.buildURL(defaultEndpoint), reqBody) if err != nil { slog.Error(err.Error()) return nil, err } + fmt.Println("ReqBody" + reqBody.String()) + req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) req.Header.Set("Content-Type", "application/json") @@ -274,32 +262,13 @@ func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { return &response, json.NewDecoder(resp.Body).Decode(&response) } -func IsAzure(apiType APIType) bool { - return apiType == APITypeAzure || apiType == APITypeAzureAD -} - -func (c *Client) buildURL(suffix string, model string) string { - if IsAzure(c.apiType) { - return c.buildAzureURL(suffix, model) - } - +func (c *Client) buildURL(suffix string) string { slog.Info("request url: " + fmt.Sprintf("%s%s", c.baseURL, suffix)) // open ai implement: return fmt.Sprintf("%s%s", c.baseURL, suffix) } -func (c *Client) buildAzureURL(suffix string, model string) string { - baseURL := c.baseURL - baseURL = strings.TrimRight(baseURL, "/") - - // azure example url: - // /openai/deployments/{model}/chat/completions?api-version={api_version} - return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", - baseURL, model, suffix, c.apiVersion, - ) -} - func (c *Client) setModel() string { if c.Provider.Model == "" { return defaultChatModel diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 33cc9717..9ab937d5 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -16,11 +16,11 @@ import ( "glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" + "github.com/go-playground/validator/v10" ) const ( - defaultBaseURL = "https://api.openai.com/v1" - defaultOrganization = "" + defaultBaseURL = "https://api.openai.com/v1" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. @@ -35,26 +35,13 @@ var ( } ) -type APIType string - -const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" -) - // Client is a client for the OpenAI API. type Client struct { - Provider providers.Provider - PoolName string - baseURL string - payload []byte - organization string - apiType APIType - httpClient *client.Client - - // required when APIType is APITypeAzure or APITypeAzureAD - apiVersion string + Provider providers.Provider `validate:"required"` + PoolName string `validate:"required"` + baseURL string `validate:"required"` + payload []byte `validate:"required"` + httpClient *client.Client `validate:"required"` } // OpenAiClient creates a new client for the OpenAI API. @@ -85,8 +72,6 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e slog.Error("Failed to unmarshal YAML: %v", err) } - fmt.Println(config) - // Find the pool with the specified name var selectedPool *providers.Pool for i := range config.Gateway.Pools { @@ -121,12 +106,20 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e // Create clients for each OpenAI provider client := &Client{ - Provider: *selectedProvider, - PoolName: poolName, - baseURL: defaultBaseURL, - payload: payload, - organization: defaultOrganization, - httpClient: HTTPClient(), + Provider: *selectedProvider, + PoolName: poolName, + baseURL: defaultBaseURL, + payload: payload, + httpClient: HTTPClient(), + } + + v := validator.New() + + if err := v.Struct(client); err != nil { + validationErrors := err.(validator.ValidationErrors) + slog.Error(validationErrors.Error()) + + return nil, validationErrors } return client, nil From ae94197b4c423a5fa50283ec56b4f122c88d26f4 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 19:07:36 -0700 Subject: [PATCH 32/61] #29: Refactor OpenAI chat functionality --- pkg/providers/openai/chat.go | 40 ++++++++++++++++++++++++++ pkg/providers/openai/openaiclient.go | 43 ++-------------------------- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index dcd5cec2..e37e6110 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -10,6 +10,8 @@ import ( "net/http" "reflect" "strings" + + "github.com/cloudwego/hertz/pkg/app/client" ) const ( @@ -66,6 +68,31 @@ type ChatUsage struct { TotalTokens int `json:"total_tokens"` } +// Chat sends a chat request to the specified OpenAI model. +// +// Parameters: +// - payload: The user payload for the chat request. +// Returns: +// - *ChatResponse: a pointer to a ChatResponse +// - error: An error if the request failed. +func (c *Client) Chat() (*ChatResponse, error) { + // Create a new chat request + + slog.Info("creating chat request") + + chatRequest := c.CreateChatRequest(c.payload) + + slog.Info("chat request created") + + // Send the chat request + + slog.Info("sending chat request") + + resp, err := c.CreateChatResponse(context.Background(), chatRequest) + + return resp, err +} + func (c *Client) CreateChatRequest(message []byte) *ChatRequest { err := json.Unmarshal(message, &requestBody) if err != nil { @@ -276,3 +303,16 @@ func (c *Client) setModel() string { return c.Provider.Model } + +// HTTPClient returns a new Hertz HTTP client. +// +// It creates a new client using the client.NewClient() function and returns the client. +// If an error occurs during the creation of the client, it logs the error using slog.Error(). +// The function returns the created client or nil if an error occurred. +func HTTPClient() *client.Client { + c, err := client.NewClient() + if err != nil { + slog.Error(err.Error()) + } + return c +} diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 9ab937d5..18de9db6 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -5,7 +5,6 @@ package openai import ( - "context" "errors" "fmt" "log/slog" @@ -57,13 +56,13 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e providerName := "openai" // Read the YAML file - data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") + data, err := os.ReadFile("config.yaml") if err != nil { slog.Error("Failed to read file: %v", err) return nil, err } - slog.Info("config found") + slog.Info("config loaded") // Unmarshal the YAML data into your struct var config providers.GatewayConfig @@ -124,41 +123,3 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e return client, nil } - -// Chat sends a chat request to the specified OpenAI model. -// -// Parameters: -// - payload: The user payload for the chat request. -// Returns: -// - *ChatResponse: a pointer to a ChatResponse -// - error: An error if the request failed. -func (c *Client) Chat() (*ChatResponse, error) { - // Create a new chat request - - slog.Info("creating chat request") - - chatRequest := c.CreateChatRequest(c.payload) - - slog.Info("chat request created") - - // Send the chat request - - slog.Info("sending chat request") - - resp, err := c.CreateChatResponse(context.Background(), chatRequest) - - return resp, err -} - -// HTTPClient returns a new Hertz HTTP client. -// -// It creates a new client using the client.NewClient() function and returns the client. -// If an error occurs during the creation of the client, it logs the error using slog.Error(). -// The function returns the created client or nil if an error occurred. -func HTTPClient() *client.Client { - c, err := client.NewClient() - if err != nil { - slog.Error(err.Error()) - } - return c -} From 45d19594e75c919e6d6129f950f24aa8820feb6f Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 20:17:09 -0700 Subject: [PATCH 33/61] #29: Fix OpenAI client error handling and validation --- pkg/providers/openai/openai_test.go | 7 ++++++- pkg/providers/openai/openaiclient.go | 17 +++-------------- pkg/providers/types.go | 14 +++++++------- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 93709a9b..1c65dcea 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "testing" + "log/slog" ) func TestOpenAIClient(t *testing.T) { @@ -30,7 +31,11 @@ func TestOpenAIClient(t *testing.T) { payloadBytes, _ := json.Marshal(payload) - c, _ := OpenAiClient(poolName, modelName, payloadBytes) + c, err := OpenAiClient(poolName, modelName, payloadBytes) + if err != nil { + slog.Error(err.Error()) + return + } resp, _ := c.Chat() diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 18de9db6..f1065073 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -15,7 +15,6 @@ import ( "glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" - "github.com/go-playground/validator/v10" ) const ( @@ -56,9 +55,8 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e providerName := "openai" // Read the YAML file - data, err := os.ReadFile("config.yaml") + data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") if err != nil { - slog.Error("Failed to read file: %v", err) return nil, err } @@ -68,7 +66,7 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e var config providers.GatewayConfig err = yaml.Unmarshal(data, &config) if err != nil { - slog.Error("Failed to unmarshal YAML: %v", err) + return nil, err } // Find the pool with the specified name @@ -99,7 +97,7 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e // Check if the provider was found if selectedProvider == nil { - slog.Error("provider for model '%s' not found in pool '%s'", modelName, poolName) + slog.Error("double check the config.yaml for errors") return nil, fmt.Errorf("provider for model '%s' not found in pool '%s'", modelName, poolName) } @@ -112,14 +110,5 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e httpClient: HTTPClient(), } - v := validator.New() - - if err := v.Struct(client); err != nil { - validationErrors := err.(validator.ValidationErrors) - slog.Error(validationErrors.Error()) - - return nil, validationErrors - } - return client, nil } diff --git a/pkg/providers/types.go b/pkg/providers/types.go index e9ce7901..cdb4609a 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,22 +1,22 @@ package providers type GatewayConfig struct { - Gateway PoolsConfig `yaml:"gateway"` + Gateway PoolsConfig `yaml:"gateway" validate:"required"` } type PoolsConfig struct { - Pools []Pool `yaml:"pools"` + Pools []Pool `yaml:"pools" validate:"required"` } type Pool struct { - Name string `yaml:"pool"` - Balancing string `yaml:"balancing"` - Providers []Provider `yaml:"providers"` + Name string `yaml:"pool" validate:"required"` + Balancing string `yaml:"balancing" validate:"required"` + Providers []Provider `yaml:"providers" validate:"required"` } type Provider struct { - Provider string `yaml:"provider"` + Provider string `yaml:"provider" validate:"required"` Model string `yaml:"model"` - APIKey string `yaml:"api_key"` + APIKey string `yaml:"api_key" validate:"required"` TimeoutMs int `yaml:"timeout_ms,omitempty"` DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` } From 93c34190e9ad441c7a34c84c1fb0877999762f8f Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 20:21:29 -0700 Subject: [PATCH 34/61] #29: Fix OpenAI client error handling and validation --- pkg/providers/openai/openai_test.go | 2 +- pkg/providers/openai/openaiclient.go | 75 ++++++++++++++-------------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 1c65dcea..27f3b5d2 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -3,8 +3,8 @@ package openai import ( "encoding/json" "fmt" - "testing" "log/slog" + "testing" ) func TestOpenAIClient(t *testing.T) { diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index f1065073..3d2d8d29 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -19,6 +19,7 @@ import ( const ( defaultBaseURL = "https://api.openai.com/v1" + providerName = "openai" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. @@ -52,63 +53,63 @@ type Client struct { // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) { - providerName := "openai" - + // Read the YAML file data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read YAML file: %w", err) } slog.Info("config loaded") // Unmarshal the YAML data into your struct var config providers.GatewayConfig - err = yaml.Unmarshal(data, &config) - if err != nil { - return nil, err + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal YAML data: %w", err) } // Find the pool with the specified name - var selectedPool *providers.Pool - for i := range config.Gateway.Pools { - pool := &config.Gateway.Pools[i] - if pool.Name == poolName { - selectedPool = pool - break - } - } - - // Check if the pool was found - if selectedPool == nil { - slog.Error("pool not found") - return nil, fmt.Errorf("pool not found: %s", poolName) - } - - // Find the OpenAI provider in the selected pool with the specified model - var selectedProvider *providers.Provider - for i := range selectedPool.Providers { - provider := &selectedPool.Providers[i] - if provider.Provider == providerName && provider.Model == modelName { - selectedProvider = provider - break - } + selectedPool, err := findPoolByName(config.Gateway.Pools, poolName) + if err != nil { + return nil, fmt.Errorf("failed to find pool: %w", err) } - // Check if the provider was found - if selectedProvider == nil { - slog.Error("double check the config.yaml for errors") - return nil, fmt.Errorf("provider for model '%s' not found in pool '%s'", modelName, poolName) + // Find the OpenAI provider params in the selected pool with the specified model + selectedProvider, err := findProviderByModel(selectedPool.Providers, providerName, modelName) + if err != nil { + return nil, fmt.Errorf("failed to find provider: %w", err) } - // Create clients for each OpenAI provider - client := &Client{ + // Create a new client + c := &Client{ Provider: *selectedProvider, PoolName: poolName, - baseURL: defaultBaseURL, + baseURL: "", // Set the appropriate base URL payload: payload, httpClient: HTTPClient(), } - return client, nil + return c, nil +} + +func findPoolByName(pools []providers.Pool, name string) (*providers.Pool, error) { + for i := range pools { + pool := &pools[i] + if pool.Name == name { + return pool, nil + } + } + + return nil, fmt.Errorf("pool not found: %s", name) +} + +func findProviderByModel(providers []providers.Provider, providerName string, modelName string) (*providers.Provider, error) { + for i := range providers { + provider := &providers[i] + if provider.Provider == providerName && provider.Model == modelName { + return provider, nil + } + } + + return nil, fmt.Errorf("provider not found: %s", modelName) } From 0edd274684bb823a5c780fd2711a40707ca549b3 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 20:22:23 -0700 Subject: [PATCH 35/61] #29: Set the appropriate base URL in OpenAiClient constructor --- pkg/providers/openai/openaiclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 3d2d8d29..04ad687b 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -84,7 +84,7 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e c := &Client{ Provider: *selectedProvider, PoolName: poolName, - baseURL: "", // Set the appropriate base URL + baseURL: defaultBaseURL, // Set the appropriate base URL payload: payload, httpClient: HTTPClient(), } From fd9e21a1129328732d201e6e0d93c97ab1782732 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 20:40:15 -0700 Subject: [PATCH 36/61] #29: Refactor OpenAI provider configuration --- pkg/providers/openai/openaiclient.go | 26 +++++++++++++++++++++----- pkg/providers/types.go | 2 +- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 04ad687b..71190ddd 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -15,11 +15,12 @@ import ( "glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" + "github.com/go-playground/validator/v10" ) const ( defaultBaseURL = "https://api.openai.com/v1" - providerName = "openai" + providerName = "openai" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. @@ -53,7 +54,6 @@ type Client struct { // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) { - // Read the YAML file data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") if err != nil { @@ -77,7 +77,7 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e // Find the OpenAI provider params in the selected pool with the specified model selectedProvider, err := findProviderByModel(selectedPool.Providers, providerName, modelName) if err != nil { - return nil, fmt.Errorf("failed to find provider: %w", err) + return nil, fmt.Errorf("provider error: %w", err) } // Create a new client @@ -89,6 +89,12 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e httpClient: HTTPClient(), } + v := validator.New() + err = v.Struct(c) + if err != nil { + return nil, fmt.Errorf("failed to validate client: %w", err) + } + return c, nil } @@ -103,13 +109,23 @@ func findPoolByName(pools []providers.Pool, name string) (*providers.Pool, error return nil, fmt.Errorf("pool not found: %s", name) } +// findProviderByModel find provider params in the given config file by the specified provider name and model name. +// +// Parameters: +// - providers: a slice of providers.Provider, the list of providers to search in. +// - providerName: a string, the name of the provider to search for. +// - modelName: a string, the name of the model to search for. +// +// Returns: +// - *providers.Provider: a pointer to the found provider. +// - error: an error indicating whether a provider was found or not. func findProviderByModel(providers []providers.Provider, providerName string, modelName string) (*providers.Provider, error) { for i := range providers { provider := &providers[i] - if provider.Provider == providerName && provider.Model == modelName { + if provider.Name == providerName && provider.Model == modelName { return provider, nil } } - return nil, fmt.Errorf("provider not found: %s", modelName) + return nil, fmt.Errorf("no provider found in config for model: %s", modelName) } diff --git a/pkg/providers/types.go b/pkg/providers/types.go index cdb4609a..eee08ec9 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -14,7 +14,7 @@ type Pool struct { } type Provider struct { - Provider string `yaml:"provider" validate:"required"` + Name string `yaml:"name" validate:"required"` Model string `yaml:"model"` APIKey string `yaml:"api_key" validate:"required"` TimeoutMs int `yaml:"timeout_ms,omitempty"` From c7d341ccbf3ae98d32a14dd9c55377ef01ed5d14 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 20:40:59 -0700 Subject: [PATCH 37/61] #29: comments --- pkg/providers/openai/openaiclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 71190ddd..ad718a92 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -55,7 +55,7 @@ type Client struct { // - error: An error if the client creation failed. func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) { // Read the YAML file - data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") + data, err := os.ReadFile("config.yaml") // TODO: How will this be accessed? Does it have to be read each time? if err != nil { return nil, fmt.Errorf("failed to read YAML file: %w", err) } From d922ee97cd157a9d8217ac314b751922f15d3587 Mon Sep 17 00:00:00 2001 From: Max Date: Thu, 21 Dec 2023 10:49:49 -0700 Subject: [PATCH 38/61] #29: Refactor OpenAI client and related functions --- pkg/providers/openai/chat.go | 86 +++------------------------- pkg/providers/openai/openai_test.go | 2 +- pkg/providers/openai/openaiclient.go | 23 +++++--- pkg/providers/types.go | 2 +- 4 files changed, 26 insertions(+), 87 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index e37e6110..ec824425 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -10,8 +10,6 @@ import ( "net/http" "reflect" "strings" - - "github.com/cloudwego/hertz/pkg/app/client" ) const ( @@ -75,7 +73,7 @@ type ChatUsage struct { // Returns: // - *ChatResponse: a pointer to a ChatResponse // - error: An error if the request failed. -func (c *Client) Chat() (*ChatResponse, error) { +func (c *ProviderClient) Chat() (*ChatResponse, error) { // Create a new chat request slog.Info("creating chat request") @@ -93,7 +91,7 @@ func (c *Client) Chat() (*ChatResponse, error) { return resp, err } -func (c *Client) CreateChatRequest(message []byte) *ChatRequest { +func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest { err := json.Unmarshal(message, &requestBody) if err != nil { slog.Error("Error:", err) @@ -168,10 +166,10 @@ type ChatResponse struct { } // CreateChatResponse creates chat Response. -func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { +func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { _ = ctx // keep this for future use - resp, err := c.createChatHTTP(r) + resp, err := c.createChatHTTP(r) // netpoll -> hertz does not yet support tls if err != nil { return nil, err } @@ -181,61 +179,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR return resp, nil } -/* will remove later -func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { - slog.Info("running createChat") - - if payload.StreamingFunc != nil { - payload.Stream = true - } - // Build request payload - payloadBytes, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - // Build request - if c.baseURL == "" { - c.baseURL = defaultBaseURL - } - - req := &protocol.Request{} - res := &protocol.Response{} - req.Header.SetMethod(consts.MethodPost) - req.SetRequestURI(c.buildURL("/chat/completions", c.Provider.Model)) - req.SetBody(payloadBytes) - req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) - req.Header.Set("Content-Type", "application/json") - - slog.Info("making request") - - // Send request - err = c.httpClient.Do(ctx, req, res) //*client.Client - if err != nil { - slog.Error(err.Error()) - fmt.Println(res.Body()) - return nil, err - } - - slog.Info("request returned") - - defer res.ConnectionClose() // replaced r.Body.Close() - - slog.Info(fmt.Sprintf("%d", res.StatusCode())) - - if res.StatusCode() != http.StatusOK { - msg := fmt.Sprintf("API returned unexpected status code: %d", res.StatusCode()) - - return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113 - } - - // Parse response - var response ChatResponse - return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response) -} -*/ - -func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { +func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { slog.Info("running createChatHttp") if payload.StreamingFunc != nil { @@ -264,8 +208,7 @@ func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) req.Header.Set("Content-Type", "application/json") - httpClient := &http.Client{} - resp, err := httpClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { slog.Error(err.Error()) return nil, err @@ -289,30 +232,17 @@ func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { return &response, json.NewDecoder(resp.Body).Decode(&response) } -func (c *Client) buildURL(suffix string) string { +func (c *ProviderClient) buildURL(suffix string) string { slog.Info("request url: " + fmt.Sprintf("%s%s", c.baseURL, suffix)) // open ai implement: return fmt.Sprintf("%s%s", c.baseURL, suffix) } -func (c *Client) setModel() string { +func (c *ProviderClient) setModel() string { if c.Provider.Model == "" { return defaultChatModel } return c.Provider.Model } - -// HTTPClient returns a new Hertz HTTP client. -// -// It creates a new client using the client.NewClient() function and returns the client. -// If an error occurs during the creation of the client, it logs the error using slog.Error(). -// The function returns the created client or nil if an error occurred. -func HTTPClient() *client.Client { - c, err := client.NewClient() - if err != nil { - slog.Error(err.Error()) - } - return c -} diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 27f3b5d2..c95ca840 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -31,7 +31,7 @@ func TestOpenAIClient(t *testing.T) { payloadBytes, _ := json.Marshal(payload) - c, err := OpenAiClient(poolName, modelName, payloadBytes) + c, err := Client(poolName, modelName, payloadBytes) if err != nil { slog.Error(err.Error()) return diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index ad718a92..1290af8e 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -8,13 +8,14 @@ import ( "errors" "fmt" "log/slog" + "net/http" "os" + "time" "gopkg.in/yaml.v2" "glide/pkg/providers" - "github.com/cloudwego/hertz/pkg/app/client" "github.com/go-playground/validator/v10" ) @@ -35,13 +36,21 @@ var ( } ) +var httpClient = &http.Client{ + Timeout: time.Second * 60, + Transport: &http.Transport{ + MaxIdleConns: 90, + MaxIdleConnsPerHost: 5, + }, +} + // Client is a client for the OpenAI API. -type Client struct { +type ProviderClient struct { Provider providers.Provider `validate:"required"` PoolName string `validate:"required"` baseURL string `validate:"required"` payload []byte `validate:"required"` - httpClient *client.Client `validate:"required"` + httpClient *http.Client `validate:"required"` } // OpenAiClient creates a new client for the OpenAI API. @@ -53,9 +62,9 @@ type Client struct { // Returns: // - *Client: A pointer to the created client. // - error: An error if the client creation failed. -func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) { +func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) { // Read the YAML file - data, err := os.ReadFile("config.yaml") // TODO: How will this be accessed? Does it have to be read each time? + data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") // TODO: How will this be accessed? Does it have to be read each time? if err != nil { return nil, fmt.Errorf("failed to read YAML file: %w", err) } @@ -81,12 +90,12 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e } // Create a new client - c := &Client{ + c := &ProviderClient{ Provider: *selectedProvider, PoolName: poolName, baseURL: defaultBaseURL, // Set the appropriate base URL payload: payload, - httpClient: HTTPClient(), + httpClient: httpClient, } v := validator.New() diff --git a/pkg/providers/types.go b/pkg/providers/types.go index eee08ec9..f8e6c8c8 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -14,7 +14,7 @@ type Pool struct { } type Provider struct { - Name string `yaml:"name" validate:"required"` + Name string `yaml:"name" validate:"required"` Model string `yaml:"model"` APIKey string `yaml:"api_key" validate:"required"` TimeoutMs int `yaml:"timeout_ms,omitempty"` From 236a31949cca4da134fabf65af940fa4a226283c Mon Sep 17 00:00:00 2001 From: Max Date: Thu, 21 Dec 2023 10:51:06 -0700 Subject: [PATCH 39/61] #29: chores --- pkg/providers/openai/openaiclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 1290af8e..20b210b4 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -64,7 +64,7 @@ type ProviderClient struct { // - error: An error if the client creation failed. func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) { // Read the YAML file - data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") // TODO: How will this be accessed? Does it have to be read each time? + data, err := os.ReadFile("config.yaml") // TODO: How will this be accessed? Does it have to be read each time? if err != nil { return nil, fmt.Errorf("failed to read YAML file: %w", err) } From def24a05cc92e2779d2560961fe79279f917b9ce Mon Sep 17 00:00:00 2001 From: Max Date: Thu, 21 Dec 2023 11:14:37 -0700 Subject: [PATCH 40/61] #29: chores --- pkg/providers/openai/openaiclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 20b210b4..7fa31558 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -93,7 +93,7 @@ func Client(poolName string, modelName string, payload []byte) (*ProviderClient, c := &ProviderClient{ Provider: *selectedProvider, PoolName: poolName, - baseURL: defaultBaseURL, // Set the appropriate base URL + baseURL: defaultBaseURL, payload: payload, httpClient: httpClient, } From 2e3a97e21198eab454954c7b0a3dd4acc7a9a994 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 09:11:14 -0700 Subject: [PATCH 41/61] #29: comment --- pkg/providers/openai/openaiclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 7fa31558..22f8753f 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -64,7 +64,7 @@ type ProviderClient struct { // - error: An error if the client creation failed. func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) { // Read the YAML file - data, err := os.ReadFile("config.yaml") // TODO: How will this be accessed? Does it have to be read each time? + data, err := os.ReadFile("config.yaml") // TODO: Replace with struct from pools if err != nil { return nil, fmt.Errorf("failed to read YAML file: %w", err) } From 398829b991b818bd3a8ebab02f5f80d4d851d123 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 10:36:51 -0700 Subject: [PATCH 42/61] #29: create a yaml for provider global configs --- pkg/providers/openai/chat.go | 4 +- pkg/providers/openai/openaiclient.go | 66 +++++++++++++++++++++++----- pkg/providers/providerVars.yaml | 4 ++ pkg/providers/types.go | 9 ++-- 4 files changed, 66 insertions(+), 17 deletions(-) create mode 100644 pkg/providers/providerVars.yaml diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index ec824425..c5b2436c 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -193,7 +194,8 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er // Build request if c.baseURL == "" { - c.baseURL = defaultBaseURL + slog.Error("baseURL not set") + return nil, errors.New("baseURL not set") } reqBody := bytes.NewBuffer(payloadBytes) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 22f8753f..bb9494f2 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -20,8 +20,9 @@ import ( ) const ( - defaultBaseURL = "https://api.openai.com/v1" - providerName = "openai" + providerName = "openai" + providerVarPath = "/Users/max/code/Glide/pkg/providers/providerVars.yaml" + configPath = "/Users/max/code/Glide/config.yaml" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. @@ -63,18 +64,19 @@ type ProviderClient struct { // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) { - // Read the YAML file - data, err := os.ReadFile("config.yaml") // TODO: Replace with struct from pools + provVars, err := readProviderVars(providerVarPath) if err != nil { - return nil, fmt.Errorf("failed to read YAML file: %w", err) + return nil, fmt.Errorf("failed to read provider vars: %w", err) } - slog.Info("config loaded") + defaultBaseURL, err := getDefaultBaseURL(provVars, providerName) + if err != nil { + return nil, fmt.Errorf("failed to get default base URL: %w", err) + } - // Unmarshal the YAML data into your struct - var config providers.GatewayConfig - if err := yaml.Unmarshal(data, &config); err != nil { - return nil, fmt.Errorf("failed to unmarshal YAML data: %w", err) + config, err := readConfig(configPath) // TODO: replace with struct built in router/pool + if err != nil { + return nil, fmt.Errorf("failed to read config: %w", err) } // Find the pool with the specified name @@ -138,3 +140,47 @@ func findProviderByModel(providers []providers.Provider, providerName string, mo return nil, fmt.Errorf("no provider found in config for model: %s", modelName) } + +func readProviderVars(filePath string) ([]providers.ProviderVars, error) { + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read provider vars file: %w", err) + } + + var provVars []providers.ProviderVars + if err := yaml.Unmarshal(data, &provVars); err != nil { + return nil, fmt.Errorf("failed to unmarshal provider vars data: %w", err) + } + + return provVars, nil +} + +func getDefaultBaseURL(provVars []providers.ProviderVars, providerName string) (string, error) { + providerVarsMap := make(map[string]string) + for _, providerVar := range provVars { + providerVarsMap[providerVar.Name] = providerVar.ChatBaseURL + } + + defaultBaseURL, ok := providerVarsMap[providerName] + if !ok { + return "", fmt.Errorf("default base URL not found for provider: %s", providerName) + } + + return defaultBaseURL, nil +} + +func readConfig(filePath string) (providers.GatewayConfig, error) { + data, err := os.ReadFile(filePath) + if err != nil { + slog.Error("Error:", err) + return providers.GatewayConfig{}, fmt.Errorf("failed to read config file: %w", err) + } + + var config providers.GatewayConfig + if err := yaml.Unmarshal(data, &config); err != nil { + slog.Error("Error:", err) + return providers.GatewayConfig{}, fmt.Errorf("failed to unmarshal config data: %w", err) + } + + return config, nil +} diff --git a/pkg/providers/providerVars.yaml b/pkg/providers/providerVars.yaml new file mode 100644 index 00000000..91f90794 --- /dev/null +++ b/pkg/providers/providerVars.yaml @@ -0,0 +1,4 @@ +- name: openai + chatBaseURL: https://api.openai.com/v1 +- name: cohere + chatBaseURL: https://api.cohere.ai/v1/chat \ No newline at end of file diff --git a/pkg/providers/types.go b/pkg/providers/types.go index f8e6c8c8..b8ba2e7c 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -21,10 +21,7 @@ type Provider struct { DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` } -type RequestBody struct { - Message []struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - MessageHistory []string `json:"messageHistory"` +type ProviderVars struct { + Name string `yaml:"name"` + ChatBaseURL string `yaml:"chatBaseURL"` } From cda3fc9176511128565a9d180833b24926759126 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 10:37:54 -0700 Subject: [PATCH 43/61] #29: comments --- pkg/providers/openai/openaiclient.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index bb9494f2..7beeb178 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -79,13 +79,13 @@ func Client(poolName string, modelName string, payload []byte) (*ProviderClient, return nil, fmt.Errorf("failed to read config: %w", err) } - // Find the pool with the specified name + // Find the pool with the specified name from global config. This may not be necessary selectedPool, err := findPoolByName(config.Gateway.Pools, poolName) if err != nil { return nil, fmt.Errorf("failed to find pool: %w", err) } - // Find the OpenAI provider params in the selected pool with the specified model + // Find the OpenAI provider params in the selected pool with the specified model. This may not be necessary selectedProvider, err := findProviderByModel(selectedPool.Providers, providerName, modelName) if err != nil { return nil, fmt.Errorf("provider error: %w", err) From acac5bdc33ba32f7c8d686a06eff4d7f929a7277 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 10:40:23 -0700 Subject: [PATCH 44/61] #29: comments --- pkg/providers/openai/openaiclient.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 7beeb178..54479c7c 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -79,13 +79,13 @@ func Client(poolName string, modelName string, payload []byte) (*ProviderClient, return nil, fmt.Errorf("failed to read config: %w", err) } - // Find the pool with the specified name from global config. This may not be necessary + // Find the pool with the specified name from global config. This may not be necessary if details are passed directly in struct selectedPool, err := findPoolByName(config.Gateway.Pools, poolName) if err != nil { return nil, fmt.Errorf("failed to find pool: %w", err) } - // Find the OpenAI provider params in the selected pool with the specified model. This may not be necessary + // Find the OpenAI provider params in the selected pool with the specified model. This may not be necessary if details are passed directly in struct selectedProvider, err := findProviderByModel(selectedPool.Providers, providerName, modelName) if err != nil { return nil, fmt.Errorf("provider error: %w", err) From 3f8bf68e82cce7fedd04364cbb53031bf089f60c Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 16:04:23 -0700 Subject: [PATCH 45/61] #29: update http client --- pkg/providers/openai/openaiclient.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 54479c7c..47ff1a1d 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -38,10 +38,10 @@ var ( ) var httpClient = &http.Client{ - Timeout: time.Second * 60, + Timeout: time.Second * 30, Transport: &http.Transport{ - MaxIdleConns: 90, - MaxIdleConnsPerHost: 5, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, }, } From 0065cde97e69182259928e2a317769089b410838 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 16:32:14 -0700 Subject: [PATCH 46/61] #29: Add file path validation and error handling in openaiclient.go --- pkg/providers/openai/openaiclient.go | 33 ++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 47ff1a1d..c7933533 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -10,6 +10,7 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "time" "gopkg.in/yaml.v2" @@ -142,7 +143,21 @@ func findProviderByModel(providers []providers.Provider, providerName string, mo } func readProviderVars(filePath string) ([]providers.ProviderVars, error) { - data, err := os.ReadFile(filePath) + absPath, err := filepath.Abs(filePath) + if err != nil { + return nil, fmt.Errorf("failed to get absolute file path: %w", err) + } + + // Validate that the absolute path is a file + fileInfo, err := os.Stat(absPath) + if err != nil { + return nil, fmt.Errorf("failed to get file info: %w", err) + } + if fileInfo.IsDir() { + return nil, fmt.Errorf("provided path is a directory, not a file") + } + + data, err := os.ReadFile(absPath) if err != nil { return nil, fmt.Errorf("failed to read provider vars file: %w", err) } @@ -170,7 +185,21 @@ func getDefaultBaseURL(provVars []providers.ProviderVars, providerName string) ( } func readConfig(filePath string) (providers.GatewayConfig, error) { - data, err := os.ReadFile(filePath) + absPath, err := filepath.Abs(filePath) + if err != nil { + return providers.GatewayConfig{}, fmt.Errorf("failed to get absolute file path: %w", err) + } + + // Validate that the absolute path is a file + fileInfo, err := os.Stat(absPath) + if err != nil { + return providers.GatewayConfig{}, fmt.Errorf("failed to get file info: %w", err) + } + if fileInfo.IsDir() { + return providers.GatewayConfig{}, fmt.Errorf("provided path is a directory, not a file") + } + + data, err := os.ReadFile(absPath) if err != nil { slog.Error("Error:", err) return providers.GatewayConfig{}, fmt.Errorf("failed to read config file: %w", err) From f5a8ecd5fafa019379f85c82f0757f9c0a7591c7 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 17:49:13 -0700 Subject: [PATCH 47/61] #29: move common helpers to types.go --- pkg/providers/openai/chat.go | 23 +++- pkg/providers/openai/openaiclient.go | 150 ++------------------------- pkg/providers/types.go | 137 ++++++++++++++++++++++++ 3 files changed, 163 insertions(+), 147 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index c5b2436c..7bd832c5 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "glide/pkg/providers" "io" "log/slog" "net/http" @@ -18,6 +19,17 @@ const ( defaultEndpoint = "/chat/completions" ) +// Client is a client for the OpenAI API. +type ProviderClient struct { + Provider providers.Provider `validate:"required"` + PoolName string `validate:"required"` + BaseURL string `validate:"required"` + Payload []byte `validate:"required"` + HttpClient *http.Client `validate:"required"` +} + + + // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { Model string `json:"model" validate:"required,lowercase"` @@ -79,7 +91,7 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) { slog.Info("creating chat request") - chatRequest := c.CreateChatRequest(c.payload) + chatRequest := c.CreateChatRequest(c.Payload) slog.Info("chat request created") @@ -93,6 +105,7 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) { } func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest { + var requestBody providers.RequestBody err := json.Unmarshal(message, &requestBody) if err != nil { slog.Error("Error:", err) @@ -193,7 +206,7 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er } // Build request - if c.baseURL == "" { + if c.BaseURL == "" { slog.Error("baseURL not set") return nil, errors.New("baseURL not set") } @@ -210,7 +223,7 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) req.Header.Set("Content-Type", "application/json") - resp, err := c.httpClient.Do(req) + resp, err := c.HttpClient.Do(req) if err != nil { slog.Error(err.Error()) return nil, err @@ -235,10 +248,10 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er } func (c *ProviderClient) buildURL(suffix string) string { - slog.Info("request url: " + fmt.Sprintf("%s%s", c.baseURL, suffix)) + slog.Info("request url: " + fmt.Sprintf("%s%s", c.BaseURL, suffix)) // open ai implement: - return fmt.Sprintf("%s%s", c.baseURL, suffix) + return fmt.Sprintf("%s%s", c.BaseURL, suffix) } func (c *ProviderClient) setModel() string { diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index c7933533..14620770 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -7,13 +7,6 @@ package openai import ( "errors" "fmt" - "log/slog" - "net/http" - "os" - "path/filepath" - "time" - - "gopkg.in/yaml.v2" "glide/pkg/providers" @@ -29,31 +22,8 @@ const ( // ErrEmptyResponse is returned when the OpenAI API returns an empty response. var ( ErrEmptyResponse = errors.New("empty response") - requestBody struct { - Message []struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - MessageHistory []string `json:"messageHistory"` - } ) -var httpClient = &http.Client{ - Timeout: time.Second * 30, - Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 2, - }, -} - -// Client is a client for the OpenAI API. -type ProviderClient struct { - Provider providers.Provider `validate:"required"` - PoolName string `validate:"required"` - baseURL string `validate:"required"` - payload []byte `validate:"required"` - httpClient *http.Client `validate:"required"` -} // OpenAiClient creates a new client for the OpenAI API. // @@ -65,29 +35,29 @@ type ProviderClient struct { // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) { - provVars, err := readProviderVars(providerVarPath) + provVars, err := providers.ReadProviderVars(providerVarPath) if err != nil { return nil, fmt.Errorf("failed to read provider vars: %w", err) } - defaultBaseURL, err := getDefaultBaseURL(provVars, providerName) + defaultBaseURL, err := providers.GetDefaultBaseURL(provVars, providerName) if err != nil { return nil, fmt.Errorf("failed to get default base URL: %w", err) } - config, err := readConfig(configPath) // TODO: replace with struct built in router/pool + config, err := providers.ReadConfig(configPath) // TODO: replace with struct built in router/pool if err != nil { return nil, fmt.Errorf("failed to read config: %w", err) } // Find the pool with the specified name from global config. This may not be necessary if details are passed directly in struct - selectedPool, err := findPoolByName(config.Gateway.Pools, poolName) + selectedPool, err := providers.FindPoolByName(config.Gateway.Pools, poolName) if err != nil { return nil, fmt.Errorf("failed to find pool: %w", err) } // Find the OpenAI provider params in the selected pool with the specified model. This may not be necessary if details are passed directly in struct - selectedProvider, err := findProviderByModel(selectedPool.Providers, providerName, modelName) + selectedProvider, err := providers.FindProviderByModel(selectedPool.Providers, providerName, modelName) if err != nil { return nil, fmt.Errorf("provider error: %w", err) } @@ -96,9 +66,9 @@ func Client(poolName string, modelName string, payload []byte) (*ProviderClient, c := &ProviderClient{ Provider: *selectedProvider, PoolName: poolName, - baseURL: defaultBaseURL, - payload: payload, - httpClient: httpClient, + BaseURL: defaultBaseURL, + Payload: payload, + HttpClient: providers.HTTPClient, } v := validator.New() @@ -109,107 +79,3 @@ func Client(poolName string, modelName string, payload []byte) (*ProviderClient, return c, nil } - -func findPoolByName(pools []providers.Pool, name string) (*providers.Pool, error) { - for i := range pools { - pool := &pools[i] - if pool.Name == name { - return pool, nil - } - } - - return nil, fmt.Errorf("pool not found: %s", name) -} - -// findProviderByModel find provider params in the given config file by the specified provider name and model name. -// -// Parameters: -// - providers: a slice of providers.Provider, the list of providers to search in. -// - providerName: a string, the name of the provider to search for. -// - modelName: a string, the name of the model to search for. -// -// Returns: -// - *providers.Provider: a pointer to the found provider. -// - error: an error indicating whether a provider was found or not. -func findProviderByModel(providers []providers.Provider, providerName string, modelName string) (*providers.Provider, error) { - for i := range providers { - provider := &providers[i] - if provider.Name == providerName && provider.Model == modelName { - return provider, nil - } - } - - return nil, fmt.Errorf("no provider found in config for model: %s", modelName) -} - -func readProviderVars(filePath string) ([]providers.ProviderVars, error) { - absPath, err := filepath.Abs(filePath) - if err != nil { - return nil, fmt.Errorf("failed to get absolute file path: %w", err) - } - - // Validate that the absolute path is a file - fileInfo, err := os.Stat(absPath) - if err != nil { - return nil, fmt.Errorf("failed to get file info: %w", err) - } - if fileInfo.IsDir() { - return nil, fmt.Errorf("provided path is a directory, not a file") - } - - data, err := os.ReadFile(absPath) - if err != nil { - return nil, fmt.Errorf("failed to read provider vars file: %w", err) - } - - var provVars []providers.ProviderVars - if err := yaml.Unmarshal(data, &provVars); err != nil { - return nil, fmt.Errorf("failed to unmarshal provider vars data: %w", err) - } - - return provVars, nil -} - -func getDefaultBaseURL(provVars []providers.ProviderVars, providerName string) (string, error) { - providerVarsMap := make(map[string]string) - for _, providerVar := range provVars { - providerVarsMap[providerVar.Name] = providerVar.ChatBaseURL - } - - defaultBaseURL, ok := providerVarsMap[providerName] - if !ok { - return "", fmt.Errorf("default base URL not found for provider: %s", providerName) - } - - return defaultBaseURL, nil -} - -func readConfig(filePath string) (providers.GatewayConfig, error) { - absPath, err := filepath.Abs(filePath) - if err != nil { - return providers.GatewayConfig{}, fmt.Errorf("failed to get absolute file path: %w", err) - } - - // Validate that the absolute path is a file - fileInfo, err := os.Stat(absPath) - if err != nil { - return providers.GatewayConfig{}, fmt.Errorf("failed to get file info: %w", err) - } - if fileInfo.IsDir() { - return providers.GatewayConfig{}, fmt.Errorf("provided path is a directory, not a file") - } - - data, err := os.ReadFile(absPath) - if err != nil { - slog.Error("Error:", err) - return providers.GatewayConfig{}, fmt.Errorf("failed to read config file: %w", err) - } - - var config providers.GatewayConfig - if err := yaml.Unmarshal(data, &config); err != nil { - slog.Error("Error:", err) - return providers.GatewayConfig{}, fmt.Errorf("failed to unmarshal config data: %w", err) - } - - return config, nil -} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index b8ba2e7c..37fefc6b 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,5 +1,17 @@ package providers +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + "net/http" + "time" + + "gopkg.in/yaml.v2" + +) + type GatewayConfig struct { Gateway PoolsConfig `yaml:"gateway" validate:"required"` } @@ -25,3 +37,128 @@ type ProviderVars struct { Name string `yaml:"name"` ChatBaseURL string `yaml:"chatBaseURL"` } + +type RequestBody struct { + Message []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + MessageHistory []string `json:"messageHistory"` + } + +// Variables + +var HTTPClient = &http.Client{ + Timeout: time.Second * 30, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, +} + +// Helper Functions + +func FindPoolByName(pools []Pool, name string) (*Pool, error) { + for i := range pools { + pool := &pools[i] + if pool.Name == name { + return pool, nil + } + } + + return nil, fmt.Errorf("pool not found: %s", name) +} + +// findProviderByModel find provider params in the given config file by the specified provider name and model name. +// +// Parameters: +// - providers: a slice of Provider, the list of providers to search in. +// - providerName: a string, the name of the provider to search for. +// - modelName: a string, the name of the model to search for. +// +// Returns: +// - *Provider: a pointer to the found provider. +// - error: an error indicating whether a provider was found or not. +func FindProviderByModel(providers []Provider, providerName string, modelName string) (*Provider, error) { + for i := range providers { + provider := &providers[i] + if provider.Name == providerName && provider.Model == modelName { + return provider, nil + } + } + + return nil, fmt.Errorf("no provider found in config for model: %s", modelName) +} + +func ReadProviderVars(filePath string) ([]ProviderVars, error) { + absPath, err := filepath.Abs(filePath) + if err != nil { + return nil, fmt.Errorf("failed to get absolute file path: %w", err) + } + + // Validate that the absolute path is a file + fileInfo, err := os.Stat(absPath) + if err != nil { + return nil, fmt.Errorf("failed to get file info: %w", err) + } + if fileInfo.IsDir() { + return nil, fmt.Errorf("provided path is a directory, not a file") + } + + data, err := os.ReadFile(absPath) + if err != nil { + return nil, fmt.Errorf("failed to read provider vars file: %w", err) + } + + var provVars []ProviderVars + if err := yaml.Unmarshal(data, &provVars); err != nil { + return nil, fmt.Errorf("failed to unmarshal provider vars data: %w", err) + } + + return provVars, nil +} + +func GetDefaultBaseURL(provVars []ProviderVars, providerName string) (string, error) { + providerVarsMap := make(map[string]string) + for _, providerVar := range provVars { + providerVarsMap[providerVar.Name] = providerVar.ChatBaseURL + } + + defaultBaseURL, ok := providerVarsMap[providerName] + if !ok { + return "", fmt.Errorf("default base URL not found for provider: %s", providerName) + } + + return defaultBaseURL, nil +} + +func ReadConfig(filePath string) (GatewayConfig, error) { + absPath, err := filepath.Abs(filePath) + if err != nil { + return GatewayConfig{}, fmt.Errorf("failed to get absolute file path: %w", err) + } + + // Validate that the absolute path is a file + fileInfo, err := os.Stat(absPath) + if err != nil { + return GatewayConfig{}, fmt.Errorf("failed to get file info: %w", err) + } + if fileInfo.IsDir() { + return GatewayConfig{}, fmt.Errorf("provided path is a directory, not a file") + } + + data, err := os.ReadFile(absPath) + if err != nil { + slog.Error("Error:", err) + return GatewayConfig{}, fmt.Errorf("failed to read config file: %w", err) + } + + var config GatewayConfig + if err := yaml.Unmarshal(data, &config); err != nil { + slog.Error("Error:", err) + return GatewayConfig{}, fmt.Errorf("failed to unmarshal config data: %w", err) + } + + return config, nil +} + From e8fef9f3ef5211ca50114c27615c9b9c2c393c22 Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 22 Dec 2023 18:32:01 -0700 Subject: [PATCH 48/61] #29: chores --- pkg/providers/openai/chat.go | 39 ++++++++++++++-------------- pkg/providers/openai/openaiclient.go | 3 +-- pkg/providers/types.go | 16 +++++------- 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 7bd832c5..da6a5aab 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -6,12 +6,13 @@ import ( "encoding/json" "errors" "fmt" - "glide/pkg/providers" "io" "log/slog" "net/http" "reflect" "strings" + + "glide/pkg/providers" ) const ( @@ -25,11 +26,9 @@ type ProviderClient struct { PoolName string `validate:"required"` BaseURL string `validate:"required"` Payload []byte `validate:"required"` - HttpClient *http.Client `validate:"required"` + HTTPClient *http.Client `validate:"required"` } - - // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { Model string `json:"model" validate:"required,lowercase"` @@ -79,6 +78,20 @@ type ChatUsage struct { TotalTokens int `json:"total_tokens"` } +// ChatResponse is a response to a chat request. +type ChatResponse struct { + ID string `json:"id,omitempty"` + Created float64 `json:"created,omitempty"` + Choices []*ChatChoice `json:"choices,omitempty"` + Model string `json:"model,omitempty"` + Object string `json:"object,omitempty"` + Usage struct { + CompletionTokens float64 `json:"completion_tokens,omitempty"` + PromptTokens float64 `json:"prompt_tokens,omitempty"` + TotalTokens float64 `json:"total_tokens,omitempty"` + } `json:"usage,omitempty"` +} + // Chat sends a chat request to the specified OpenAI model. // // Parameters: @@ -165,25 +178,11 @@ func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest { return chatRequest } -// ChatResponse is a response to a chat request. -type ChatResponse struct { - ID string `json:"id,omitempty"` - Created float64 `json:"created,omitempty"` - Choices []*ChatChoice `json:"choices,omitempty"` - Model string `json:"model,omitempty"` - Object string `json:"object,omitempty"` - Usage struct { - CompletionTokens float64 `json:"completion_tokens,omitempty"` - PromptTokens float64 `json:"prompt_tokens,omitempty"` - TotalTokens float64 `json:"total_tokens,omitempty"` - } `json:"usage,omitempty"` -} - // CreateChatResponse creates chat Response. func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { _ = ctx // keep this for future use - resp, err := c.createChatHTTP(r) // netpoll -> hertz does not yet support tls + resp, err := c.createChatHTTP(r) // netpoll/hertz does not yet support tls if err != nil { return nil, err } @@ -223,7 +222,7 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) req.Header.Set("Content-Type", "application/json") - resp, err := c.HttpClient.Do(req) + resp, err := c.HTTPClient.Do(req) if err != nil { slog.Error(err.Error()) return nil, err diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 14620770..a22f9c07 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -24,7 +24,6 @@ var ( ErrEmptyResponse = errors.New("empty response") ) - // OpenAiClient creates a new client for the OpenAI API. // // Parameters: @@ -68,7 +67,7 @@ func Client(poolName string, modelName string, payload []byte) (*ProviderClient, PoolName: poolName, BaseURL: defaultBaseURL, Payload: payload, - HttpClient: providers.HTTPClient, + HTTPClient: providers.HTTPClient, } v := validator.New() diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 37fefc6b..48e024f3 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -3,13 +3,12 @@ package providers import ( "fmt" "log/slog" + "net/http" "os" "path/filepath" - "net/http" "time" "gopkg.in/yaml.v2" - ) type GatewayConfig struct { @@ -39,12 +38,12 @@ type ProviderVars struct { } type RequestBody struct { - Message []struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - MessageHistory []string `json:"messageHistory"` - } + Message []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + MessageHistory []string `json:"messageHistory"` +} // Variables @@ -161,4 +160,3 @@ func ReadConfig(filePath string) (GatewayConfig, error) { return config, nil } - From cbcfaf3c0fc94bf6600e2291ddfe5fc106687b37 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 10:31:31 -0600 Subject: [PATCH 49/61] #29: Refactor OpenAI client and chat request creation --- .gitignore | 3 ++- pkg/providers/openai/chat.go | 38 ++++++++++----------------- pkg/providers/openai/openai_test.go | 17 ------------ pkg/providers/openai/openaiclient.go | 39 +++------------------------- pkg/providers/types.go | 10 +++++++ 5 files changed, 30 insertions(+), 77 deletions(-) diff --git a/.gitignore b/.gitignore index f6da76a8..f7a23953 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ glide tmp coverage.txt precommit.txt -openai_test.go \ No newline at end of file +openai_test.go +pkg/providers/openai/openai_test.go \ No newline at end of file diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index da6a5aab..2b135938 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -22,17 +22,15 @@ const ( // Client is a client for the OpenAI API. type ProviderClient struct { - Provider providers.Provider `validate:"required"` - PoolName string `validate:"required"` BaseURL string `validate:"required"` - Payload []byte `validate:"required"` + UnifiedData providers.UnifiedAPIData `validate:"required"` HTTPClient *http.Client `validate:"required"` } // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { Model string `json:"model" validate:"required,lowercase"` - Messages []*ChatMessage `json:"messages" validate:"required"` + Messages []string `json:"messages" validate:"required"` Temperature float64 `json:"temperature,omitempty" validate:"omitempty,gte=0,lte=1"` TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` @@ -104,7 +102,7 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) { slog.Info("creating chat request") - chatRequest := c.CreateChatRequest(c.Payload) + chatRequest := c.CreateChatRequest(c.UnifiedData) slog.Info("chat request created") @@ -117,28 +115,20 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) { return resp, err } -func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest { - var requestBody providers.RequestBody - err := json.Unmarshal(message, &requestBody) - if err != nil { - slog.Error("Error:", err) - return nil - } +func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) *ChatRequest { slog.Info("creating chatRequest from payload") - var messages []*ChatMessage - for _, msg := range requestBody.Message { - chatMsg := &ChatMessage{ - Role: msg.Role, - Content: msg.Content, - } - if msg.Role == "user" { - chatMsg.Content += " " + strings.Join(requestBody.MessageHistory, " ") - } - messages = append(messages, chatMsg) + var messages []string + + // Add items from messageHistory first + for _, history := range unifiedData.MessageHistory { + messages = append(messages, history) } + // Add msg variable last + messages = append(messages, unifiedData.Message) + // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value chatRequest := &ChatRequest{ @@ -161,7 +151,7 @@ func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest { } // Use reflection to dynamically assign default parameter values - defaultParams := c.Provider.DefaultParams + defaultParams := unifiedData.Params chatRequestValue := reflect.ValueOf(chatRequest).Elem() chatRequestType := chatRequestValue.Type() @@ -219,7 +209,7 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er fmt.Println("ReqBody" + reqBody.String()) - req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey) + req.Header.Set("Authorization", "Bearer "+c.UnifiedData.APIKey) req.Header.Set("Content-Type", "application/json") resp, err := c.HTTPClient.Do(req) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index c95ca840..24f486fd 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -12,24 +12,7 @@ func TestOpenAIClient(t *testing.T) { _ = t - poolName := "default" - modelName := "gpt-3.5-turbo" - - payload := map[string]interface{}{ - "message": []map[string]string{ - { - "role": "system", - "content": "You are a helpful assistant.", - }, - { - "role": "user", - "content": "tell me a joke", - }, - }, - "messageHistory": []string{"Hello there", "How are you?", "I'm good, how about you?"}, - } - payloadBytes, _ := json.Marshal(payload) c, err := Client(poolName, modelName, payloadBytes) if err != nil { diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index a22f9c07..1fe2edc8 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -17,6 +17,7 @@ const ( providerName = "openai" providerVarPath = "/Users/max/code/Glide/pkg/providers/providerVars.yaml" configPath = "/Users/max/code/Glide/config.yaml" + defaultBaseURL = "https://api.openai.com/v1" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. @@ -24,6 +25,7 @@ var ( ErrEmptyResponse = errors.New("empty response") ) + // OpenAiClient creates a new client for the OpenAI API. // // Parameters: @@ -33,48 +35,15 @@ var ( // Returns: // - *Client: A pointer to the created client. // - error: An error if the client creation failed. -func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) { - provVars, err := providers.ReadProviderVars(providerVarPath) - if err != nil { - return nil, fmt.Errorf("failed to read provider vars: %w", err) - } - - defaultBaseURL, err := providers.GetDefaultBaseURL(provVars, providerName) - if err != nil { - return nil, fmt.Errorf("failed to get default base URL: %w", err) - } +func Client(UnifiedData providers.UnifiedAPIData) (*ProviderClient, error) { - config, err := providers.ReadConfig(configPath) // TODO: replace with struct built in router/pool - if err != nil { - return nil, fmt.Errorf("failed to read config: %w", err) - } - - // Find the pool with the specified name from global config. This may not be necessary if details are passed directly in struct - selectedPool, err := providers.FindPoolByName(config.Gateway.Pools, poolName) - if err != nil { - return nil, fmt.Errorf("failed to find pool: %w", err) - } - - // Find the OpenAI provider params in the selected pool with the specified model. This may not be necessary if details are passed directly in struct - selectedProvider, err := providers.FindProviderByModel(selectedPool.Providers, providerName, modelName) - if err != nil { - return nil, fmt.Errorf("provider error: %w", err) - } // Create a new client c := &ProviderClient{ - Provider: *selectedProvider, - PoolName: poolName, BaseURL: defaultBaseURL, - Payload: payload, + UnifiedData: UnifiedData, HTTPClient: providers.HTTPClient, } - v := validator.New() - err = v.Struct(c) - if err != nil { - return nil, fmt.Errorf("failed to validate client: %w", err) - } - return c, nil } diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 48e024f3..e213e5fd 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -55,6 +55,16 @@ var HTTPClient = &http.Client{ }, } +type UnifiedAPIData struct { + Provider string `json:"provider"` + Model string `json:"model"` + APIKey string `json:"api_key"` + Params map[string]interface{} `json:"params"` + Message string `json:"message"` + MessageHistory []string `json:"messageHistory"` +} + + // Helper Functions func FindPoolByName(pools []Pool, name string) (*Pool, error) { From 9a14b7192c7dac74ef33650e4716305012c5516e Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 10:51:34 -0600 Subject: [PATCH 50/61] #29: Remove unused dependencies and update dependencies --- go.mod | 9 +-------- go.sum | 18 ------------------ pkg/providers/openai/chat.go | 8 +++----- pkg/providers/openai/openaiclient.go | 2 -- 4 files changed, 4 insertions(+), 33 deletions(-) diff --git a/go.mod b/go.mod index 834c468f..3fdf7da3 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21.5 require ( github.com/cloudwego/hertz v0.7.3 - github.com/go-playground/validator/v10 v10.16.0 github.com/spf13/cobra v1.8.0 go.uber.org/goleak v1.3.0 go.uber.org/multierr v1.11.0 @@ -18,9 +17,6 @@ require ( github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/cloudwego/netpoll v0.5.0 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/henrylee2cn/ameda v1.4.10 // indirect @@ -28,17 +24,14 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/kr/text v0.2.0 // indirect - github.com/leodido/go-urn v1.2.4 // indirect github.com/nyaruka/phonenumbers v1.0.55 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/testify v1.8.2 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/crypto v0.7.0 // indirect - golang.org/x/net v0.8.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.8.0 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index cbd30fe6..725f36f7 100644 --- a/go.sum +++ b/go.sum @@ -21,16 +21,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= -github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -54,8 +44,6 @@ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -97,11 +85,7 @@ golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5P golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -109,8 +93,6 @@ golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 2b135938..83f66934 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -122,9 +122,7 @@ func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) var messages []string // Add items from messageHistory first - for _, history := range unifiedData.MessageHistory { - messages = append(messages, history) - } + messages = append(messages, unifiedData.MessageHistory...) // Add msg variable last messages = append(messages, unifiedData.Message) @@ -244,9 +242,9 @@ func (c *ProviderClient) buildURL(suffix string) string { } func (c *ProviderClient) setModel() string { - if c.Provider.Model == "" { + if c.UnifiedData.Model == "" { return defaultChatModel } - return c.Provider.Model + return c.UnifiedData.Model } diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 1fe2edc8..0ccb30d2 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -6,11 +6,9 @@ package openai import ( "errors" - "fmt" "glide/pkg/providers" - "github.com/go-playground/validator/v10" ) const ( From 906df8dd72b78bcdff8283dc936a970d2ee3ee54 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 11:19:03 -0600 Subject: [PATCH 51/61] #29: Update Unified Data structure --- go.mod | 1 + go.sum | 2 ++ pkg/providers/openai/chat.go | 45 ++++++++++++++-------------- pkg/providers/openai/openai_test.go | 39 ++++++++++++++++++++++-- pkg/providers/openai/openaiclient.go | 10 ++----- pkg/providers/types.go | 11 ++++--- 6 files changed, 70 insertions(+), 38 deletions(-) diff --git a/go.mod b/go.mod index 3fdf7da3..723b92cc 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21.5 require ( github.com/cloudwego/hertz v0.7.3 + github.com/joho/godotenv v1.5.1 github.com/spf13/cobra v1.8.0 go.uber.org/goleak v1.3.0 go.uber.org/multierr v1.11.0 diff --git a/go.sum b/go.sum index 725f36f7..69cc381e 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhY github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 83f66934..a4ab2c92 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -22,29 +22,29 @@ const ( // Client is a client for the OpenAI API. type ProviderClient struct { - BaseURL string `validate:"required"` - UnifiedData providers.UnifiedAPIData `validate:"required"` - HTTPClient *http.Client `validate:"required"` + BaseURL string `validate:"required"` + UnifiedData providers.UnifiedAPIData `validate:"required"` + HTTPClient *http.Client `validate:"required"` } // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { - Model string `json:"model" validate:"required,lowercase"` - Messages []string `json:"messages" validate:"required"` - Temperature float64 `json:"temperature,omitempty" validate:"omitempty,gte=0,lte=1"` - TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` - MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` - N int `json:"n,omitempty" validate:"omitempty,gte=1"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty" validate:"omitempty"` - User interface{} `json:"user,omitempty"` - Seed interface{} `json:"seed,omitempty" validate:"omitempty,gte=0"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` + Model string `json:"model" validate:"required,lowercase"` + Messages []map[string]string `json:"messages" validate:"required"` + Temperature float64 `json:"temperature,omitempty" validate:"omitempty,gte=0,lte=1"` + TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` + MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` + N int `json:"n,omitempty" validate:"omitempty,gte=1"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty" validate:"omitempty"` + User interface{} `json:"user,omitempty"` + Seed interface{} `json:"seed,omitempty" validate:"omitempty,gte=0"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` // StreamingFunc is a function to be called for each chunk of a streaming response. // Return an error to stop streaming early. @@ -116,10 +116,9 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) { } func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) *ChatRequest { - slog.Info("creating chatRequest from payload") - var messages []string + var messages []map[string]string // Add items from messageHistory first messages = append(messages, unifiedData.MessageHistory...) @@ -205,7 +204,7 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er return nil, err } - fmt.Println("ReqBody" + reqBody.String()) + fmt.Println(reqBody.String()) req.Header.Set("Authorization", "Bearer "+c.UnifiedData.APIKey) req.Header.Set("Content-Type", "application/json") @@ -246,5 +245,5 @@ func (c *ProviderClient) setModel() string { return defaultChatModel } - return c.UnifiedData.Model + return c.UnifiedData.Model } diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 24f486fd..9b453ec3 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -4,17 +4,52 @@ import ( "encoding/json" "fmt" "log/slog" + "os" "testing" + + "glide/pkg/providers" + + "github.com/joho/godotenv" ) func TestOpenAIClient(t *testing.T) { // Initialize the OpenAI client - _ = t + err := godotenv.Load("/Users/max/code/Glide/.env") + if err != nil { + fmt.Println("Error loading .env file" + err.Error()) + } + _ = t + fakeData := providers.UnifiedAPIData{ + Provider: "openai", + Model: "gpt-3.5-turbo", + APIKey: os.Getenv("openai"), + Params: map[string]interface{}{ + "temperature": 0.3, + }, + Message: map[string]string{ + "role": "user", + "content": "Where was it played?", + }, + MessageHistory: []map[string]string{ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": "Who won the world series in 2020?", + }, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + }, + } - c, err := Client(poolName, modelName, payloadBytes) + c, err := Client(fakeData) if err != nil { slog.Error(err.Error()) return diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 0ccb30d2..05e9e201 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -8,7 +8,6 @@ import ( "errors" "glide/pkg/providers" - ) const ( @@ -23,7 +22,6 @@ var ( ErrEmptyResponse = errors.New("empty response") ) - // OpenAiClient creates a new client for the OpenAI API. // // Parameters: @@ -34,13 +32,11 @@ var ( // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func Client(UnifiedData providers.UnifiedAPIData) (*ProviderClient, error) { - - // Create a new client c := &ProviderClient{ - BaseURL: defaultBaseURL, - UnifiedData: UnifiedData, - HTTPClient: providers.HTTPClient, + BaseURL: defaultBaseURL, + UnifiedData: UnifiedData, + HTTPClient: providers.HTTPClient, } return c, nil diff --git a/pkg/providers/types.go b/pkg/providers/types.go index e213e5fd..16127539 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -56,15 +56,14 @@ var HTTPClient = &http.Client{ } type UnifiedAPIData struct { - Provider string `json:"provider"` - Model string `json:"model"` - APIKey string `json:"api_key"` + Provider string `json:"provider"` + Model string `json:"model"` + APIKey string `json:"api_key"` Params map[string]interface{} `json:"params"` - Message string `json:"message"` - MessageHistory []string `json:"messageHistory"` + Message map[string]string `json:"message"` + MessageHistory []map[string]string `json:"messageHistory"` } - // Helper Functions func FindPoolByName(pools []Pool, name string) (*Pool, error) { From 19e8f46b239b2e8982f57c7c4882d2daf33e53f7 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 11:20:59 -0600 Subject: [PATCH 52/61] #29: remove provider --- pkg/providers/openai/openai_test.go | 1 - pkg/providers/types.go | 1 - 2 files changed, 2 deletions(-) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 9b453ec3..f93729d6 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -23,7 +23,6 @@ func TestOpenAIClient(t *testing.T) { _ = t fakeData := providers.UnifiedAPIData{ - Provider: "openai", Model: "gpt-3.5-turbo", APIKey: os.Getenv("openai"), Params: map[string]interface{}{ diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 16127539..742db648 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -56,7 +56,6 @@ var HTTPClient = &http.Client{ } type UnifiedAPIData struct { - Provider string `json:"provider"` Model string `json:"model"` APIKey string `json:"api_key"` Params map[string]interface{} `json:"params"` From 597381caa204c1d86c96c8ef38106ffa1e1a814c Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 11:30:02 -0600 Subject: [PATCH 53/61] #29: clean up --- pkg/providers/openai/openai_test.go | 62 ----------------------------- 1 file changed, 62 deletions(-) delete mode 100644 pkg/providers/openai/openai_test.go diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go deleted file mode 100644 index f93729d6..00000000 --- a/pkg/providers/openai/openai_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - "log/slog" - "os" - "testing" - - "glide/pkg/providers" - - "github.com/joho/godotenv" -) - -func TestOpenAIClient(t *testing.T) { - // Initialize the OpenAI client - - err := godotenv.Load("/Users/max/code/Glide/.env") - if err != nil { - fmt.Println("Error loading .env file" + err.Error()) - } - - _ = t - - fakeData := providers.UnifiedAPIData{ - Model: "gpt-3.5-turbo", - APIKey: os.Getenv("openai"), - Params: map[string]interface{}{ - "temperature": 0.3, - }, - Message: map[string]string{ - "role": "user", - "content": "Where was it played?", - }, - MessageHistory: []map[string]string{ - { - "role": "system", - "content": "You are a helpful assistant.", - }, - { - "role": "user", - "content": "Who won the world series in 2020?", - }, - { - "role": "assistant", - "content": "The Los Angeles Dodgers won the World Series in 2020.", - }, - }, - } - - c, err := Client(fakeData) - if err != nil { - slog.Error(err.Error()) - return - } - - resp, _ := c.Chat() - - respJSON, _ := json.Marshal(resp) - - fmt.Println(string(respJSON)) -} From 602325400788fda43a5ce9a874cb90d8cf649b2f Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 11:30:18 -0600 Subject: [PATCH 54/61] #29: clean up --- .gitignore | 3 +-- pkg/providers/openai/chat.go | 2 +- pkg/providers/openai/openaiclient.go | 2 -- pkg/providers/providerVars.yaml | 4 ---- 4 files changed, 2 insertions(+), 9 deletions(-) delete mode 100644 pkg/providers/providerVars.yaml diff --git a/.gitignore b/.gitignore index f7a23953..5b2ff21e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,4 @@ glide tmp coverage.txt precommit.txt -openai_test.go -pkg/providers/openai/openai_test.go \ No newline at end of file +/Users/max/code/Glide/pkg/providers/openai/openai_test.go \ No newline at end of file diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index a4ab2c92..53d8c021 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -126,7 +126,7 @@ func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) // Add msg variable last messages = append(messages, unifiedData.Message) - // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value + // iterate throughunifiedData.Params and add them to the request otherwise leave the default value chatRequest := &ChatRequest{ Model: c.setModel(), diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 05e9e201..6d130149 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -12,8 +12,6 @@ import ( const ( providerName = "openai" - providerVarPath = "/Users/max/code/Glide/pkg/providers/providerVars.yaml" - configPath = "/Users/max/code/Glide/config.yaml" defaultBaseURL = "https://api.openai.com/v1" ) diff --git a/pkg/providers/providerVars.yaml b/pkg/providers/providerVars.yaml deleted file mode 100644 index 91f90794..00000000 --- a/pkg/providers/providerVars.yaml +++ /dev/null @@ -1,4 +0,0 @@ -- name: openai - chatBaseURL: https://api.openai.com/v1 -- name: cohere - chatBaseURL: https://api.cohere.ai/v1/chat \ No newline at end of file From 1d369a38e1e40c79d924d172cbcf591ea4507cb0 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 11:31:20 -0600 Subject: [PATCH 55/61] #29: lint --- .gitignore | 2 +- pkg/providers/openai/openaiclient.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 5b2ff21e..8b3ecc23 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,4 @@ glide tmp coverage.txt precommit.txt -/Users/max/code/Glide/pkg/providers/openai/openai_test.go \ No newline at end of file +pkg/providers/openai/openai_test.go \ No newline at end of file diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 6d130149..6311d4a7 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -11,8 +11,8 @@ import ( ) const ( - providerName = "openai" - defaultBaseURL = "https://api.openai.com/v1" + providerName = "openai" + defaultBaseURL = "https://api.openai.com/v1" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. From a9598e9bde010e8afcfa9feb920e0861d5fd8f60 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 26 Dec 2023 11:32:59 -0600 Subject: [PATCH 56/61] #29: remove hertz comment --- pkg/providers/openai/chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 53d8c021..da2c1682 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -169,7 +169,7 @@ func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { _ = ctx // keep this for future use - resp, err := c.createChatHTTP(r) // netpoll/hertz does not yet support tls + resp, err := c.createChatHTTP(r) if err != nil { return nil, err } From 98c029d02a649b33f0e2620ba34bbd2205e33eed Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 27 Dec 2023 10:34:58 -0600 Subject: [PATCH 57/61] #29: pass unified data to chat method --- go.mod | 3 +- go.sum | 1 - pkg/buildAPIRequest.go | 58 ----------------------- pkg/providers/openai/chat.go | 70 +++++++++++----------------- pkg/providers/openai/openaiclient.go | 7 ++- 5 files changed, 31 insertions(+), 108 deletions(-) delete mode 100644 pkg/buildAPIRequest.go diff --git a/go.mod b/go.mod index ec8c613d..7a85ebcb 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.21.5 require ( github.com/cloudwego/hertz v0.7.3 - github.com/joho/godotenv v1.5.1 github.com/hertz-contrib/logger/zap v1.1.0 + github.com/joho/godotenv v1.5.1 github.com/spf13/cobra v1.8.0 go.uber.org/goleak v1.3.0 go.uber.org/multierr v1.11.0 @@ -26,7 +26,6 @@ require ( github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect - github.com/kr/text v0.2.0 // indirect github.com/nyaruka/phonenumbers v1.0.55 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/testify v1.8.2 // indirect diff --git a/go.sum b/go.sum index 98669d5e..b0bb129a 100644 --- a/go.sum +++ b/go.sum @@ -15,7 +15,6 @@ github.com/cloudwego/hertz v0.7.3/go.mod h1:WliNtVbwihWHHgAaIQEbVXl0O3aWj0ks1eoP github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU= github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/buildAPIRequest.go b/pkg/buildAPIRequest.go deleted file mode 100644 index e8f280e8..00000000 --- a/pkg/buildAPIRequest.go +++ /dev/null @@ -1,58 +0,0 @@ -// this file contains the BuildAPIRequest function which takes in the provider name, params map, and mode and returns the providerConfig map and error -// The providerConfig map can be used to build the API request to the provider -package pkg - -import ( - "errors" - "fmt" - - "glide/pkg/providers" - "glide/pkg/providers/openai" - - "github.com/go-playground/validator/v10" -) - -type ProviderConfigs = pkg.ProviderConfigs - -// Initialize configList - -var configList = map[string]interface{}{ - "openai": openai.OpenAIConfig, -} - -// Create a new validator instance -var validate *validator.Validate = validator.New() - -func BuildAPIRequest(provider string, params map[string]string, mode string) (interface{}, error) { - // provider is the name of the provider, e.g. "openai", params is the map of parameters from the client, - // mode is the mode of the provider, e.g. "chat", configList is the list of provider configurations - var providerConfig map[string]interface{} - - if config, ok := configList[provider].(ProviderConfigs); ok { - if modeConfig, ok := config[mode].(map[string]interface{}); ok { - providerConfig = modeConfig - } - } - - // If the provider is not supported, return an error - if providerConfig == nil { - return nil, errors.New("unsupported provider") - } - - // Build the providerConfig map by iterating over the keys in the providerConfig map and checking if the key exists in the params map - - for key := range providerConfig { - if value, exists := params[key]; exists { - providerConfig[key] = value - } - } - - // Validate the providerConfig map using the validator package - err := validate.Struct(providerConfig) - if err != nil { - // Handle validation error - return nil, fmt.Errorf("validation error: %v", err) - } - // If everything is fine, return the providerConfig and nil error - return providerConfig, nil -} diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index a7e0a60f..02f3d101 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -22,26 +22,25 @@ const ( // Client is a client for the OpenAI API. type ProviderClient struct { - BaseURL string `json:"baseURL"` - UnifiedData providers.UnifiedAPIData `json:"unifiedData"` - HTTPClient *http.Client `json:"httpClient"` + BaseURL string `json:"baseURL"` + HTTPClient *http.Client `json:"httpClient"` } // ChatRequest is a request to complete a chat completion.. type ChatRequest struct { Model string `json:"model"` Messages []map[string]string `json:"messages"` - Temperature float64 `json:"temperature,omitempty" validate:"omitempty,gte=0,lte=1"` - TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` - MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` - N int `json:"n,omitempty" validate:"omitempty,gte=1"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` + Stream bool `json:"stream,omitempty"` FrequencyPenalty int `json:"frequency_penalty,omitempty"` PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty" validate:"omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` User interface{} `json:"user,omitempty"` - Seed interface{} `json:"seed,omitempty" validate:"omitempty,gte=0"` + Seed interface{} `json:"seed,omitempty"` Tools []string `json:"tools,omitempty"` ToolChoice interface{} `json:"tool_choice,omitempty"` ResponseFormat interface{} `json:"response_format,omitempty"` @@ -69,13 +68,6 @@ type ChatChoice struct { FinishReason string `json:"finish_reason"` } -// ChatUsage is the usage of a chat completion request. -type ChatUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - // ChatResponse is a response to a chat request. type ChatResponse struct { ID string `json:"id,omitempty"` @@ -97,12 +89,12 @@ type ChatResponse struct { // Returns: // - *ChatResponse: a pointer to a ChatResponse // - error: An error if the request failed. -func (c *ProviderClient) Chat() (*ChatResponse, error) { +func (c *ProviderClient) Chat(u *providers.UnifiedAPIData) (*ChatResponse, error) { // Create a new chat request slog.Info("creating chat request") - chatRequest := c.CreateChatRequest(c.UnifiedData) + chatRequest := CreateChatRequest(u) slog.Info("chat request created") @@ -110,26 +102,26 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) { slog.Info("sending chat request") - resp, err := c.CreateChatResponse(context.Background(), chatRequest) + resp, err := CreateChatResponse(context.Background(), chatRequest, u) return resp, err } -func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) *ChatRequest { +func CreateChatRequest(u *providers.UnifiedAPIData) *ChatRequest { slog.Info("creating chatRequest from payload") var messages []map[string]string // Add items from messageHistory first - messages = append(messages, unifiedData.MessageHistory...) + messages = append(messages, u.MessageHistory...) // Add msg variable last - messages = append(messages, unifiedData.Message) + messages = append(messages, u.Message) // iterate throughunifiedData.Params and add them to the request otherwise leave the default value chatRequest := &ChatRequest{ - Model: c.setModel(), + Model: u.Model, Messages: messages, Temperature: 0.8, TopP: 1, @@ -149,7 +141,7 @@ func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) // Use reflection to dynamically assign default parameter values // TODO: refactor to avoid reflection(?) - defaultParams := unifiedData.Params + defaultParams := u.Params chatRequestValue := reflect.ValueOf(chatRequest).Elem() chatRequestType := chatRequestValue.Type() @@ -167,10 +159,10 @@ func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) } // CreateChatResponse creates chat Response. -func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { +func CreateChatResponse(ctx context.Context, r *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { _ = ctx // keep this for future use - resp, err := c.createChatHTTP(r) + resp, err := createChatHTTP(r, u) if err != nil { return nil, err } @@ -180,7 +172,7 @@ func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest) return resp, nil } -func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) { +func createChatHTTP(payload *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { slog.Info("running createChatHttp") if payload.StreamingFunc != nil { @@ -193,13 +185,13 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er } // Build request - if c.BaseURL == "" { + if defaultBaseURL == "" { slog.Error("baseURL not set") return nil, errors.New("baseURL not set") } reqBody := bytes.NewBuffer(payloadBytes) - req, err := http.NewRequest("POST", c.buildURL(defaultEndpoint), reqBody) + req, err := http.NewRequest("POST", buildURL(defaultEndpoint), reqBody) if err != nil { slog.Error(err.Error()) return nil, err @@ -207,10 +199,10 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er fmt.Println(reqBody.String()) - req.Header.Set("Authorization", "Bearer "+c.UnifiedData.APIKey) + req.Header.Set("Authorization", "Bearer "+u.APIKey) req.Header.Set("Content-Type", "application/json") - resp, err := c.HTTPClient.Do(req) + resp, err := providers.HTTPClient.Do(req) if err != nil { slog.Error(err.Error()) return nil, err @@ -234,17 +226,9 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er return &response, json.NewDecoder(resp.Body).Decode(&response) } -func (c *ProviderClient) buildURL(suffix string) string { - slog.Info("request url: " + fmt.Sprintf("%s%s", c.BaseURL, suffix)) +func buildURL(suffix string) string { + slog.Info("request url: " + fmt.Sprintf("%s%s", defaultBaseURL, suffix)) // open ai implement: - return fmt.Sprintf("%s%s", c.BaseURL, suffix) -} - -func (c *ProviderClient) setModel() string { - if c.UnifiedData.Model == "" { - return defaultChatModel - } - - return c.UnifiedData.Model + return fmt.Sprintf("%s%s", defaultBaseURL, suffix) } diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 6311d4a7..ddbabff0 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -29,12 +29,11 @@ var ( // Returns: // - *Client: A pointer to the created client. // - error: An error if the client creation failed. -func Client(UnifiedData providers.UnifiedAPIData) (*ProviderClient, error) { +func Client() (*ProviderClient, error) { // Create a new client c := &ProviderClient{ - BaseURL: defaultBaseURL, - UnifiedData: UnifiedData, - HTTPClient: providers.HTTPClient, + BaseURL: defaultBaseURL, + HTTPClient: providers.HTTPClient, } return c, nil From a4b41f023cb3f0373af40ad36bdfc0c4047bb867 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 27 Dec 2023 17:06:59 -0600 Subject: [PATCH 58/61] #29: init logging --- pkg/providers/openai/chat.go | 5 ++++- pkg/providers/openai/openaiclient.go | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 02f3d101..00fdd0ab 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -13,6 +13,8 @@ import ( "strings" "glide/pkg/providers" + + "glide/pkg/telemetry" ) const ( @@ -24,6 +26,7 @@ const ( type ProviderClient struct { BaseURL string `json:"baseURL"` HTTPClient *http.Client `json:"httpClient"` + Telemetry *telemetry.Telemetry `json:"telemetry"` } // ChatRequest is a request to complete a chat completion.. @@ -92,7 +95,7 @@ type ChatResponse struct { func (c *ProviderClient) Chat(u *providers.UnifiedAPIData) (*ChatResponse, error) { // Create a new chat request - slog.Info("creating chat request") + c.Telemetry.Logger.Info("creating new chat request") chatRequest := CreateChatRequest(u) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index ddbabff0..6988c2a4 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -8,6 +8,8 @@ import ( "errors" "glide/pkg/providers" + + "glide/pkg/telemetry" ) const ( @@ -30,10 +32,20 @@ var ( // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func Client() (*ProviderClient, error) { + + tel, err := telemetry.NewTelemetry(&telemetry.Config{LogConfig: telemetry.NewLogConfig()}) + if err != nil { + return nil, err + } + + tel.Logger.Info("init openai provider client") + + // Create a new client c := &ProviderClient{ BaseURL: defaultBaseURL, HTTPClient: providers.HTTPClient, + Telemetry: tel, } return c, nil From 80a28c9624eae80df8f95be1a8d4da5fee7e073f Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 27 Dec 2023 17:20:58 -0600 Subject: [PATCH 59/61] #29: init logging --- pkg/providers/openai/chat.go | 39 +++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 00fdd0ab..72d9de7b 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -15,6 +15,8 @@ import ( "glide/pkg/providers" "glide/pkg/telemetry" + + "go.uber.org/zap" ) const ( @@ -97,21 +99,22 @@ func (c *ProviderClient) Chat(u *providers.UnifiedAPIData) (*ChatResponse, error c.Telemetry.Logger.Info("creating new chat request") - chatRequest := CreateChatRequest(u) + chatRequest := c.CreateChatRequest(u) - slog.Info("chat request created") + c.Telemetry.Logger.Info("chat request created") // Send the chat request slog.Info("sending chat request") - resp, err := CreateChatResponse(context.Background(), chatRequest, u) + resp, err := c.CreateChatResponse(context.Background(), chatRequest, u) return resp, err } -func CreateChatRequest(u *providers.UnifiedAPIData) *ChatRequest { - slog.Info("creating chatRequest from payload") +func (c *ProviderClient) CreateChatRequest(u *providers.UnifiedAPIData) *ChatRequest { + + c.Telemetry.Logger.Info("creating chatRequest from payload") var messages []map[string]string @@ -158,14 +161,16 @@ func CreateChatRequest(u *providers.UnifiedAPIData) *ChatRequest { } } + // c.Telemetry.Logger.Info("chatRequest created", zap.Any("chatRequest body", chatRequest)) + return chatRequest } // CreateChatResponse creates chat Response. -func CreateChatResponse(ctx context.Context, r *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { +func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { _ = ctx // keep this for future use - resp, err := createChatHTTP(r, u) + resp, err := c.createChatHTTP(r, u) if err != nil { return nil, err } @@ -175,8 +180,9 @@ func CreateChatResponse(ctx context.Context, r *ChatRequest, u *providers.Unifie return resp, nil } -func createChatHTTP(payload *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { - slog.Info("running createChatHttp") +func (c *ProviderClient) createChatHTTP(payload *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { + + c.Telemetry.Logger.Info("running createChatHttp") if payload.StreamingFunc != nil { payload.Stream = true @@ -189,39 +195,36 @@ func createChatHTTP(payload *ChatRequest, u *providers.UnifiedAPIData) (*ChatRes // Build request if defaultBaseURL == "" { - slog.Error("baseURL not set") + c.Telemetry.Logger.Error("defaultBaseURL not set") return nil, errors.New("baseURL not set") } reqBody := bytes.NewBuffer(payloadBytes) req, err := http.NewRequest("POST", buildURL(defaultEndpoint), reqBody) if err != nil { - slog.Error(err.Error()) + c.Telemetry.Logger.Error(err.Error()) return nil, err } - fmt.Println(reqBody.String()) - req.Header.Set("Authorization", "Bearer "+u.APIKey) req.Header.Set("Content-Type", "application/json") resp, err := providers.HTTPClient.Do(req) if err != nil { - slog.Error(err.Error()) + c.Telemetry.Logger.Error(err.Error()) return nil, err } defer resp.Body.Close() - slog.Info(fmt.Sprintf("%d", resp.StatusCode)) + c.Telemetry.Logger.Info("Response Code: ", zap.String("response_code", fmt.Sprintf("%d", resp.StatusCode))) if resp.StatusCode != http.StatusOK { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - slog.Error(err.Error()) + c.Telemetry.Logger.Error(err.Error()) } - bodyString := string(bodyBytes) - slog.Warn(bodyString) + c.Telemetry.Logger.Warn("Response Body: ", zap.String("response_body", string(bodyBytes))) } // Parse response From 8cbe50ec4ccd7c5bcbe1739b510e68325985cc00 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 27 Dec 2023 17:21:38 -0600 Subject: [PATCH 60/61] #29: init logging --- pkg/providers/openai/chat.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 72d9de7b..d69c4120 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "log/slog" "net/http" "reflect" "strings" @@ -105,8 +104,6 @@ func (c *ProviderClient) Chat(u *providers.UnifiedAPIData) (*ChatResponse, error // Send the chat request - slog.Info("sending chat request") - resp, err := c.CreateChatResponse(context.Background(), chatRequest, u) return resp, err @@ -233,7 +230,6 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest, u *providers.Unifi } func buildURL(suffix string) string { - slog.Info("request url: " + fmt.Sprintf("%s%s", defaultBaseURL, suffix)) // open ai implement: return fmt.Sprintf("%s%s", defaultBaseURL, suffix) From c0b69896049df2216ae44f0e041cf94f72bbc501 Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 27 Dec 2023 17:42:24 -0600 Subject: [PATCH 61/61] #29: lint --- go.mod | 1 - go.sum | 8 -- pkg/providers/openai/chat.go | 26 +++---- pkg/providers/openai/openaiclient.go | 2 - pkg/providers/types.go | 112 --------------------------- 5 files changed, 12 insertions(+), 137 deletions(-) diff --git a/go.mod b/go.mod index 7a85ebcb..48db562f 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( go.uber.org/goleak v1.3.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.26.0 - gopkg.in/yaml.v2 v2.4.0 ) require ( diff --git a/go.sum b/go.sum index b0bb129a..40197f1d 100644 --- a/go.sum +++ b/go.sum @@ -43,10 +43,6 @@ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -104,11 +100,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index d69c4120..3471a9c9 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "reflect" + "strconv" "strings" "glide/pkg/providers" @@ -25,8 +26,8 @@ const ( // Client is a client for the OpenAI API. type ProviderClient struct { - BaseURL string `json:"baseURL"` - HTTPClient *http.Client `json:"httpClient"` + BaseURL string `json:"baseURL"` + HTTPClient *http.Client `json:"httpClient"` Telemetry *telemetry.Telemetry `json:"telemetry"` } @@ -95,7 +96,6 @@ type ChatResponse struct { // - error: An error if the request failed. func (c *ProviderClient) Chat(u *providers.UnifiedAPIData) (*ChatResponse, error) { // Create a new chat request - c.Telemetry.Logger.Info("creating new chat request") chatRequest := c.CreateChatRequest(u) @@ -110,7 +110,6 @@ func (c *ProviderClient) Chat(u *providers.UnifiedAPIData) (*ChatResponse, error } func (c *ProviderClient) CreateChatRequest(u *providers.UnifiedAPIData) *ChatRequest { - c.Telemetry.Logger.Info("creating chatRequest from payload") var messages []map[string]string @@ -121,7 +120,8 @@ func (c *ProviderClient) CreateChatRequest(u *providers.UnifiedAPIData) *ChatReq // Add msg variable last messages = append(messages, u.Message) - // iterate throughunifiedData.Params and add them to the request otherwise leave the default value + // Iterate through unifiedData.Params and add them to the request, otherwise leave the default value + defaultParams := u.Params chatRequest := &ChatRequest{ Model: u.Model, @@ -142,16 +142,13 @@ func (c *ProviderClient) CreateChatRequest(u *providers.UnifiedAPIData) *ChatReq ResponseFormat: nil, } - // Use reflection to dynamically assign default parameter values - // TODO: refactor to avoid reflection(?) - defaultParams := u.Params - chatRequestValue := reflect.ValueOf(chatRequest).Elem() chatRequestType := chatRequestValue.Type() for i := 0; i < chatRequestValue.NumField(); i++ { jsonTags := strings.Split(chatRequestType.Field(i).Tag.Get("json"), ",") jsonTag := jsonTags[0] + if value, ok := defaultParams[jsonTag]; ok { fieldValue := chatRequestValue.Field(i) fieldValue.Set(reflect.ValueOf(value)) @@ -171,14 +168,15 @@ func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest, if err != nil { return nil, err } + if len(resp.Choices) == 0 { return nil, ErrEmptyResponse } + return resp, nil } func (c *ProviderClient) createChatHTTP(payload *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { - c.Telemetry.Logger.Info("running createChatHttp") if payload.StreamingFunc != nil { @@ -197,7 +195,7 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest, u *providers.Unifi } reqBody := bytes.NewBuffer(payloadBytes) - req, err := http.NewRequest("POST", buildURL(defaultEndpoint), reqBody) + req, err := http.NewRequest(http.MethodPost, buildURL(defaultEndpoint), reqBody) if err != nil { c.Telemetry.Logger.Error(err.Error()) return nil, err @@ -213,24 +211,24 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest, u *providers.Unifi } defer resp.Body.Close() - c.Telemetry.Logger.Info("Response Code: ", zap.String("response_code", fmt.Sprintf("%d", resp.StatusCode))) + c.Telemetry.Logger.Info("Response Code: ", zap.String("response_code", strconv.Itoa(resp.StatusCode))) if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) if err != nil { c.Telemetry.Logger.Error(err.Error()) } + c.Telemetry.Logger.Warn("Response Body: ", zap.String("response_body", string(bodyBytes))) } // Parse response var response ChatResponse + return &response, json.NewDecoder(resp.Body).Decode(&response) } func buildURL(suffix string) string { - // open ai implement: return fmt.Sprintf("%s%s", defaultBaseURL, suffix) } diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 6988c2a4..3b9060ba 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -32,7 +32,6 @@ var ( // - *Client: A pointer to the created client. // - error: An error if the client creation failed. func Client() (*ProviderClient, error) { - tel, err := telemetry.NewTelemetry(&telemetry.Config{LogConfig: telemetry.NewLogConfig()}) if err != nil { return nil, err @@ -40,7 +39,6 @@ func Client() (*ProviderClient, error) { tel.Logger.Info("init openai provider client") - // Create a new client c := &ProviderClient{ BaseURL: defaultBaseURL, diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 742db648..2c07565f 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,14 +1,8 @@ package providers import ( - "fmt" - "log/slog" "net/http" - "os" - "path/filepath" "time" - - "gopkg.in/yaml.v2" ) type GatewayConfig struct { @@ -62,109 +56,3 @@ type UnifiedAPIData struct { Message map[string]string `json:"message"` MessageHistory []map[string]string `json:"messageHistory"` } - -// Helper Functions - -func FindPoolByName(pools []Pool, name string) (*Pool, error) { - for i := range pools { - pool := &pools[i] - if pool.Name == name { - return pool, nil - } - } - - return nil, fmt.Errorf("pool not found: %s", name) -} - -// findProviderByModel find provider params in the given config file by the specified provider name and model name. -// -// Parameters: -// - providers: a slice of Provider, the list of providers to search in. -// - providerName: a string, the name of the provider to search for. -// - modelName: a string, the name of the model to search for. -// -// Returns: -// - *Provider: a pointer to the found provider. -// - error: an error indicating whether a provider was found or not. -func FindProviderByModel(providers []Provider, providerName string, modelName string) (*Provider, error) { - for i := range providers { - provider := &providers[i] - if provider.Name == providerName && provider.Model == modelName { - return provider, nil - } - } - - return nil, fmt.Errorf("no provider found in config for model: %s", modelName) -} - -func ReadProviderVars(filePath string) ([]ProviderVars, error) { - absPath, err := filepath.Abs(filePath) - if err != nil { - return nil, fmt.Errorf("failed to get absolute file path: %w", err) - } - - // Validate that the absolute path is a file - fileInfo, err := os.Stat(absPath) - if err != nil { - return nil, fmt.Errorf("failed to get file info: %w", err) - } - if fileInfo.IsDir() { - return nil, fmt.Errorf("provided path is a directory, not a file") - } - - data, err := os.ReadFile(absPath) - if err != nil { - return nil, fmt.Errorf("failed to read provider vars file: %w", err) - } - - var provVars []ProviderVars - if err := yaml.Unmarshal(data, &provVars); err != nil { - return nil, fmt.Errorf("failed to unmarshal provider vars data: %w", err) - } - - return provVars, nil -} - -func GetDefaultBaseURL(provVars []ProviderVars, providerName string) (string, error) { - providerVarsMap := make(map[string]string) - for _, providerVar := range provVars { - providerVarsMap[providerVar.Name] = providerVar.ChatBaseURL - } - - defaultBaseURL, ok := providerVarsMap[providerName] - if !ok { - return "", fmt.Errorf("default base URL not found for provider: %s", providerName) - } - - return defaultBaseURL, nil -} - -func ReadConfig(filePath string) (GatewayConfig, error) { - absPath, err := filepath.Abs(filePath) - if err != nil { - return GatewayConfig{}, fmt.Errorf("failed to get absolute file path: %w", err) - } - - // Validate that the absolute path is a file - fileInfo, err := os.Stat(absPath) - if err != nil { - return GatewayConfig{}, fmt.Errorf("failed to get file info: %w", err) - } - if fileInfo.IsDir() { - return GatewayConfig{}, fmt.Errorf("provided path is a directory, not a file") - } - - data, err := os.ReadFile(absPath) - if err != nil { - slog.Error("Error:", err) - return GatewayConfig{}, fmt.Errorf("failed to read config file: %w", err) - } - - var config GatewayConfig - if err := yaml.Unmarshal(data, &config); err != nil { - slog.Error("Error:", err) - return GatewayConfig{}, fmt.Errorf("failed to unmarshal config data: %w", err) - } - - return config, nil -}