From db38f443a9b0a829803ebf19d4cdc362e946afbc Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 19:56:17 -0700 Subject: [PATCH] #29: tests passing --- .gitignore | 1 + pkg/providers/openai/chat.go | 145 +++++++++++++++---------- pkg/providers/openai/openai_test.go | 38 +++++++ pkg/providers/openai/openaiclient.go | 151 +++++++++------------------ pkg/providers/types.go | 11 +- 5 files changed, 188 insertions(+), 158 deletions(-) create mode 100644 pkg/providers/openai/openai_test.go diff --git a/.gitignore b/.gitignore index 951841c7..012b879f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea dist .env +config.yaml diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 022f1125..44c0fd6c 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -1,7 +1,6 @@ package openai import ( - "bufio" "bytes" "context" "encoding/json" @@ -10,6 +9,7 @@ import ( "reflect" "strings" "net/http" + "io" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -127,6 +127,8 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { } } + fmt.Println(chatRequest) + return chatRequest } @@ -171,7 +173,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR } } - resp, err := c.createChat(ctx, r) + resp, err := c.createChatHttp(ctx, r) if err != nil { return nil, err } @@ -183,6 +185,9 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { + + slog.Info("running createChat") + if payload.StreamingFunc != nil { payload.Stream = true } @@ -202,80 +207,110 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes req.Header.SetMethod(consts.MethodPost) req.SetRequestURI(c.buildURL("/chat/completions", c.Provider.Model)) req.SetBody(payloadBytes) + req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) + req.Header.Set("Content-Type", "application/json") - c.setHeaders(req) // sets additional headers + slog.Info("making request") // Send request err = c.httpClient.Do(ctx, req, res) //*client.Client if err != nil { + slog.Error(err.Error()) + fmt.Println(res.Body()) return nil, err } + slog.Info("request returned") + defer res.ConnectionClose() // replaced r.Body.Close() + + + slog.Info(fmt.Sprintf("%d", res.StatusCode())) + if res.StatusCode() != http.StatusOK { msg := fmt.Sprintf("API returned unexpected status code: %d", res.StatusCode()) return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113 } - if payload.StreamingFunc != nil { - return parseStreamingChatResponse(ctx, res, payload) - } + // Parse response var response ChatResponse return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response) } -func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, payload *ChatRequest) (*ChatResponse, error) { - scanner := bufio.NewScanner(bytes.NewReader(r.Body())) - responseChan := make(chan StreamedChatResponsePayload) - go func() { - defer close(responseChan) - for scanner.Scan() { - line := scanner.Text() - if line == "" { - continue - } - if !strings.HasPrefix(line, "data:") { - slog.Warn("unexpected line:" + line) - } - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - return - } - var streamPayload StreamedChatResponsePayload - err := json.NewDecoder(bytes.NewReader([]byte(data))).Decode(&streamPayload) - if err != nil { - slog.Error("failed to decode stream payload: %v", err) - } - responseChan <- streamPayload - } - if err := scanner.Err(); err != nil { - slog.Error("issue scanning response:", err) - } - }() +func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { + slog.Info("running createChatHttp") - // Parse response - response := ChatResponse{ - Choices: []*ChatChoice{ - {}, - }, + if payload.StreamingFunc != nil { + payload.Stream = true + } + // Build request payload + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err } - for streamResponse := range responseChan { - if len(streamResponse.Choices) == 0 { - continue - } - chunk := []byte(streamResponse.Choices[0].Delta.Content) - response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content - response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason - - if payload.StreamingFunc != nil { - err := payload.StreamingFunc(ctx, chunk) - if err != nil { - return nil, fmt.Errorf("streaming func returned an error: %w", err) - } - } + // Build request + if c.baseURL == "" { + c.baseURL = defaultBaseURL + } + + reqBody := bytes.NewBuffer(payloadBytes) + req, err := http.NewRequest("POST", c.buildURL("/chat/completions", c.Provider.Model), reqBody) + if err != nil { + slog.Error(err.Error()) + return nil, err } - return &response, nil -} \ No newline at end of file + + req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + slog.Error(err.Error()) + return nil, err + } + defer resp.Body.Close() + + slog.Info(fmt.Sprintf("%d", resp.StatusCode)) + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + slog.Error(err.Error()) +} + bodyString := string(bodyBytes) + slog.Info(bodyString) + + // Parse response + var response ChatResponse + return &response, json.NewDecoder(resp.Body).Decode(&response) +} + +func IsAzure(apiType APIType) bool { + return apiType == APITypeAzure || apiType == APITypeAzureAD +} + + +func (c *Client) buildURL(suffix string, model string) string { + if IsAzure(c.apiType) { + return c.buildAzureURL(suffix, model) + } + + slog.Info("request url: " + fmt.Sprintf("%s%s", c.baseURL, suffix)) + + // open ai implement: + return fmt.Sprintf("%s%s", c.baseURL, suffix) +} + +func (c *Client) buildAzureURL(suffix string, model string) string { + baseURL := c.baseURL + baseURL = strings.TrimRight(baseURL, "/") + + // azure example url: + // /openai/deployments/{model}/chat/completions?api-version={api_version} + return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", + baseURL, model, suffix, c.apiVersion, + ) +} diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go new file mode 100644 index 00000000..57be6eef --- /dev/null +++ b/pkg/providers/openai/openai_test.go @@ -0,0 +1,38 @@ +package openai + +import ( + "testing" + "encoding/json" + "fmt" +) + + + +func TestOpenAIClient(t *testing.T) { + // Initialize the OpenAI client + + poolName := "default" + modelName := "gpt-3.5-turbo" + + payload := map[string]interface{}{ + "message": []map[string]string{ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": "tell me a joke", + }, + }, + "messageHistory": []string{"Hello there", "How are you?", "I'm good, how about you?"}, + } + + payloadBytes, _ := json.Marshal(payload) + + c := &Client{} + + resp, err := c.Run(poolName, modelName, payloadBytes) + + fmt.Println(resp, err) +} \ No newline at end of file diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 7a1b3ba8..758a7c36 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -1,38 +1,35 @@ package openai import ( + "context" "errors" "fmt" - "strings" - "log/slog" "gopkg.in/yaml.v2" + "log/slog" "os" - "context" - //"Glide/pkg/providers" + "glide/pkg/providers" "github.com/cloudwego/hertz/pkg/app/client" - "github.com/cloudwego/hertz/pkg/protocol" - ) const ( defaultBaseURL = "https://api.openai.com/v1" - defaultOrganization = "" - defaultFunctionCallBehavior = "auto" + defaultOrganization = "" ) // ErrEmptyResponse is returned when the OpenAI API returns an empty response. var ( ErrEmptyResponse = errors.New("empty response") - requestBody struct { - Message []struct { + requestBody struct { + Message []struct { Role string `json:"role"` Content string `json:"content"` } `json:"message"` MessageHistory []string `json:"messageHistory"` } ) + type APIType string const ( @@ -40,83 +37,73 @@ const ( APITypeAzure APIType = "AZURE" APITypeAzureAD APIType = "AZURE_AD" ) - -type GatewayConfig struct { - Pools []Pool `yaml:"pools"` -} - -type Pool struct { - Name string `yaml:"name"` - Balancing string `yaml:"balancing"` - Providers []Provider `yaml:"providers"` -} - -type Provider struct { - Name string `yaml:"name"` - Provider string `yaml:"provider"` - Model string `yaml:"model"` - ApiKey string `yaml:"api_key"` - TimeoutMs int `yaml:"timeout_ms,omitempty"` - DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` -} - - - // Client is a client for the OpenAI API. type Client struct { - Provider Provider + Provider providers.Provider baseURL string organization string apiType APIType httpClient *client.Client // required when APIType is APITypeAzure or APITypeAzureAD - apiVersion string + apiVersion string } -// Option is an option for the OpenAI client. -type Option func(*Client) error -func (c *Client) Run() (*ChatResponse, error) { - - c = &Client{} +func (c *Client) Run(poolName string, modelName string, payload []byte) (*ChatResponse, error) { - c, err := c.Init("pool1", "gpt-3.5-turbo", "openai") + c, err := c.NewClient(poolName, modelName) if err != nil { slog.Error("Error:" + err.Error()) return nil, err } // Create a new chat request - chatRequest := c.CreateChatRequest([]byte("hello world")) + + slog.Info("creating chat request") + + chatRequest := c.CreateChatRequest(payload) + + slog.Info("chat request created") // Send the chat request + slog.Info("sending chat request") + resp, err := c.CreateChatResponse(context.Background(), chatRequest) return resp, err } - -func (c *Client) Init(poolName string, modelName string, providerName string) (*Client, error) { +func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // Returns a []*Client of OpenAI + // modelName is determined by the model pool + // poolName is determined by the route the request came from + + var providerName = "openai" // Read the YAML file - data, err := os.ReadFile("config.yaml") + data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") + if err != nil { slog.Error("Failed to read file: %v", err) + return nil, err } + slog.Info("config found") + // Unmarshal the YAML data into your struct - var config GatewayConfig + var config providers.GatewayConfig err = yaml.Unmarshal(data, &config) if err != nil { slog.Error("Failed to unmarshal YAML: %v", err) } + fmt.Println(config) + // Find the pool with the specified name - var selectedPool *Pool - for _, pool := range config.Pools { + var selectedPool *providers.Pool + for _, pool := range config.Gateway.Pools { if pool.Name == poolName { selectedPool = &pool break @@ -130,30 +117,30 @@ func (c *Client) Init(poolName string, modelName string, providerName string) (* } // Find the OpenAI provider in the selected pool with the specified model - var selectedProvider *Provider - for _, provider := range selectedPool.Providers { - if provider.Name == providerName && provider.Model == modelName { - selectedProvider = &provider - break - } - } - - // Check if the provider was found - if selectedProvider == nil { - slog.Error("provider for model '%s' not found in pool '%s'", modelName, poolName) + var selectedProvider *providers.Provider + for _, provider := range selectedPool.Providers { + if provider.Provider == providerName && provider.Model == modelName { + selectedProvider = &provider + break + } + } + + // Check if the provider was found + if selectedProvider == nil { + slog.Error("provider for model '%s' not found in pool '%s'", modelName, poolName) return nil, fmt.Errorf("provider for model '%s' not found in pool '%s'", modelName, poolName) - } + } // Create clients for each OpenAI provider client := &Client{ - Provider: *selectedProvider, - organization: defaultOrganization, - apiType: APITypeOpenAI, - httpClient: HttpClient(), - } + Provider: *selectedProvider, + organization: defaultOrganization, + apiType: APITypeOpenAI, + httpClient: HttpClient(), + } return client, nil - + } func HttpClient() *client.Client { @@ -165,38 +152,4 @@ func HttpClient() *client.Client { return c } -func IsAzure(apiType APIType) bool { - return apiType == APITypeAzure || apiType == APITypeAzureAD -} - -func (c *Client) setHeaders(req *protocol.Request) { - req.Header.Set("Content-Type", "application/json") - if c.apiType == APITypeOpenAI || c.apiType == APITypeAzureAD { - req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey) - } else { - req.Header.Set("api-key", c.Provider.ApiKey) - } - if c.organization != "" { - req.Header.Set("OpenAI-Organization", c.organization) - } -} -func (c *Client) buildURL(suffix string, model string) string { - if IsAzure(c.apiType) { - return c.buildAzureURL(suffix, model) - } - - // open ai implement: - return fmt.Sprintf("%s%s", c.baseURL, suffix) -} - -func (c *Client) buildAzureURL(suffix string, model string) string { - baseURL := c.baseURL - baseURL = strings.TrimRight(baseURL, "/") - - // azure example url: - // /openai/deployments/{model}/chat/completions?api-version={api_version} - return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s", - baseURL, model, suffix, c.apiVersion, - ) -} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 0f3a1262..7e4dd8e4 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,19 +1,22 @@ -package provider +package providers + type GatewayConfig struct { + Gateway PoolsConfig `yaml:"gateway"` +} +type PoolsConfig struct { Pools []Pool `yaml:"pools"` } type Pool struct { - Name string `yaml:"name"` + Name string `yaml:"pool"` Balancing string `yaml:"balancing"` Providers []Provider `yaml:"providers"` } type Provider struct { - Name string `yaml:"name"` Provider string `yaml:"provider"` Model string `yaml:"model"` - APIKey string `yaml:"api_key"` + ApiKey string `yaml:"api_key"` TimeoutMs int `yaml:"timeout_ms,omitempty"` DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` }