Skip to content

Commit

Permalink
#29: tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Dec 20, 2023
1 parent db38f44 commit 463401b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 24 deletions.
40 changes: 18 additions & 22 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"reflect"
"strings"
"net/http"
"io"

"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
Expand Down Expand Up @@ -52,7 +52,6 @@ type ChatMessage struct {
// The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores,
// with a maximum length of 64 characters.
Name string `json:"name,omitempty"`

}

// ChatChoice is a choice in a chat response.
Expand All @@ -71,7 +70,6 @@ type ChatUsage struct {

func (c *Client) CreateChatRequest(message []byte) *ChatRequest {


err := json.Unmarshal(message, &requestBody)
if err != nil {
slog.Error("Error:", err)
Expand All @@ -91,7 +89,7 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest {
}

// iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value

chatRequest := &ChatRequest{
Model: c.Provider.Model,
Messages: messages,
Expand Down Expand Up @@ -132,7 +130,6 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest {
return chatRequest
}


// ChatResponse is a response to a chat request.
type ChatResponse struct {
ID string `json:"id,omitempty"`
Expand All @@ -156,8 +153,8 @@ type StreamedChatResponsePayload struct {
Choices []struct {
Index float64 `json:"index,omitempty"`
Delta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
} `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
} `json:"choices,omitempty"`
Expand All @@ -172,7 +169,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR
r.Model = c.Provider.Model
}
}

resp, err := c.createChatHttp(ctx, r)
if err != nil {
return nil, err
Expand All @@ -183,11 +180,10 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR
return resp, nil
}

func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) {

func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) {

slog.Info("running createChat")

if payload.StreamingFunc != nil {
payload.Stream = true
}
Expand Down Expand Up @@ -224,16 +220,14 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes

defer res.ConnectionClose() // replaced r.Body.Close()



slog.Info(fmt.Sprintf("%d", res.StatusCode()))

if res.StatusCode() != http.StatusOK {
msg := fmt.Sprintf("API returned unexpected status code: %d", res.StatusCode())

return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113
}

// Parse response
var response ChatResponse
return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response)
Expand Down Expand Up @@ -276,12 +270,15 @@ func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*Cha

slog.Info(fmt.Sprintf("%d", resp.StatusCode))

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
slog.Error(err.Error())
}
bodyString := string(bodyBytes)
slog.Info(bodyString)
if resp.StatusCode != http.StatusOK {

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
slog.Error(err.Error())
}
bodyString := string(bodyBytes)
slog.Warn(bodyString)
}

// Parse response
var response ChatResponse
Expand All @@ -292,7 +289,6 @@ func IsAzure(apiType APIType) bool {
return apiType == APITypeAzure || apiType == APITypeAzureAD
}


func (c *Client) buildURL(suffix string, model string) string {
if IsAzure(c.apiType) {
return c.buildAzureURL(suffix, model)
Expand Down
6 changes: 4 additions & 2 deletions pkg/providers/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ func TestOpenAIClient(t *testing.T) {

c := &Client{}

resp, err := c.Run(poolName, modelName, payloadBytes)
resp, _ := c.Run(poolName, modelName, payloadBytes)

fmt.Println(resp, err)
respJSON, _ := json.Marshal(resp)

fmt.Println(string(respJSON))
}

0 comments on commit 463401b

Please sign in to comment.