From 463401b3a59890905fab3e120ad97d444c596eaf Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 19 Dec 2023 20:03:32 -0700 Subject: [PATCH] #29: tests passing --- pkg/providers/openai/chat.go | 40 +++++++++++++---------------- pkg/providers/openai/openai_test.go | 6 +++-- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 44c0fd6c..bf0cac93 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -5,11 +5,11 @@ import ( "context" "encoding/json" "fmt" + "io" "log/slog" + "net/http" "reflect" "strings" - "net/http" - "io" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -52,7 +52,6 @@ type ChatMessage struct { // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, // with a maximum length of 64 characters. Name string `json:"name,omitempty"` - } // ChatChoice is a choice in a chat response. @@ -71,7 +70,6 @@ type ChatUsage struct { func (c *Client) CreateChatRequest(message []byte) *ChatRequest { - err := json.Unmarshal(message, &requestBody) if err != nil { slog.Error("Error:", err) @@ -91,7 +89,7 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { } // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value - + chatRequest := &ChatRequest{ Model: c.Provider.Model, Messages: messages, @@ -132,7 +130,6 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { return chatRequest } - // ChatResponse is a response to a chat request. type ChatResponse struct { ID string `json:"id,omitempty"` @@ -156,8 +153,8 @@ type StreamedChatResponsePayload struct { Choices []struct { Index float64 `json:"index,omitempty"` Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` } `json:"delta,omitempty"` FinishReason string `json:"finish_reason,omitempty"` } `json:"choices,omitempty"` @@ -172,7 +169,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR r.Model = c.Provider.Model } } - + resp, err := c.createChatHttp(ctx, r) if err != nil { return nil, err @@ -183,11 +180,10 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR return resp, nil } +func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { -func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) { - slog.Info("running createChat") - + if payload.StreamingFunc != nil { payload.Stream = true } @@ -224,8 +220,6 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes defer res.ConnectionClose() // replaced r.Body.Close() - - slog.Info(fmt.Sprintf("%d", res.StatusCode())) if res.StatusCode() != http.StatusOK { @@ -233,7 +227,7 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113 } - + // Parse response var response ChatResponse return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response) @@ -276,12 +270,15 @@ func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*Cha slog.Info(fmt.Sprintf("%d", resp.StatusCode)) - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - slog.Error(err.Error()) -} - bodyString := string(bodyBytes) - slog.Info(bodyString) + if resp.StatusCode != http.StatusOK { + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + slog.Error(err.Error()) + } + bodyString := string(bodyBytes) + slog.Warn(bodyString) + } // Parse response var response ChatResponse @@ -292,7 +289,6 @@ func IsAzure(apiType APIType) bool { return apiType == APITypeAzure || apiType == APITypeAzureAD } - func (c *Client) buildURL(suffix string, model string) string { if IsAzure(c.apiType) { return c.buildAzureURL(suffix, model) diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index 57be6eef..9badd620 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -32,7 +32,9 @@ func TestOpenAIClient(t *testing.T) { c := &Client{} - resp, err := c.Run(poolName, modelName, payloadBytes) + resp, _ := c.Run(poolName, modelName, payloadBytes) - fmt.Println(resp, err) + respJSON, _ := json.Marshal(resp) + + fmt.Println(string(respJSON)) } \ No newline at end of file