Skip to content

Commit

Permalink
#29: tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Dec 20, 2023
1 parent a5caa77 commit db38f44
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 158 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea
dist
.env
config.yaml
145 changes: 90 additions & 55 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package openai

import (
"bufio"
"bytes"
"context"
"encoding/json"
Expand All @@ -10,6 +9,7 @@ import (
"reflect"
"strings"
"net/http"
"io"

"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
Expand Down Expand Up @@ -127,6 +127,8 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest {
}
}

fmt.Println(chatRequest)

return chatRequest
}

Expand Down Expand Up @@ -171,7 +173,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR
}
}

resp, err := c.createChat(ctx, r)
resp, err := c.createChatHttp(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -183,6 +185,9 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR


func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) {

slog.Info("running createChat")

if payload.StreamingFunc != nil {
payload.Stream = true
}
Expand All @@ -202,80 +207,110 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes
req.Header.SetMethod(consts.MethodPost)
req.SetRequestURI(c.buildURL("/chat/completions", c.Provider.Model))
req.SetBody(payloadBytes)
req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey)
req.Header.Set("Content-Type", "application/json")

c.setHeaders(req) // sets additional headers
slog.Info("making request")

// Send request
err = c.httpClient.Do(ctx, req, res) //*client.Client
if err != nil {
slog.Error(err.Error())
fmt.Println(res.Body())
return nil, err
}

slog.Info("request returned")

defer res.ConnectionClose() // replaced r.Body.Close()



slog.Info(fmt.Sprintf("%d", res.StatusCode()))

if res.StatusCode() != http.StatusOK {
msg := fmt.Sprintf("API returned unexpected status code: %d", res.StatusCode())

return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113
}
if payload.StreamingFunc != nil {
return parseStreamingChatResponse(ctx, res, payload)
}

// Parse response
var response ChatResponse
return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response)
}

func parseStreamingChatResponse(ctx context.Context, r *protocol.Response, payload *ChatRequest) (*ChatResponse, error) {
scanner := bufio.NewScanner(bytes.NewReader(r.Body()))
responseChan := make(chan StreamedChatResponsePayload)
go func() {
defer close(responseChan)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
continue
}
if !strings.HasPrefix(line, "data:") {
slog.Warn("unexpected line:" + line)
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
return
}
var streamPayload StreamedChatResponsePayload
err := json.NewDecoder(bytes.NewReader([]byte(data))).Decode(&streamPayload)
if err != nil {
slog.Error("failed to decode stream payload: %v", err)
}
responseChan <- streamPayload
}
if err := scanner.Err(); err != nil {
slog.Error("issue scanning response:", err)
}
}()
func (c *Client) createChatHttp(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) {
slog.Info("running createChatHttp")

// Parse response
response := ChatResponse{
Choices: []*ChatChoice{
{},
},
if payload.StreamingFunc != nil {
payload.Stream = true
}
// Build request payload
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}

for streamResponse := range responseChan {
if len(streamResponse.Choices) == 0 {
continue
}
chunk := []byte(streamResponse.Choices[0].Delta.Content)
response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content
response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason

if payload.StreamingFunc != nil {
err := payload.StreamingFunc(ctx, chunk)
if err != nil {
return nil, fmt.Errorf("streaming func returned an error: %w", err)
}
}
// Build request
if c.baseURL == "" {
c.baseURL = defaultBaseURL
}

reqBody := bytes.NewBuffer(payloadBytes)
req, err := http.NewRequest("POST", c.buildURL("/chat/completions", c.Provider.Model), reqBody)
if err != nil {
slog.Error(err.Error())
return nil, err
}
return &response, nil
}

req.Header.Set("Authorization", "Bearer "+c.Provider.ApiKey)
req.Header.Set("Content-Type", "application/json")

httpClient := &http.Client{}
resp, err := httpClient.Do(req)
if err != nil {
slog.Error(err.Error())
return nil, err
}
defer resp.Body.Close()

slog.Info(fmt.Sprintf("%d", resp.StatusCode))

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
slog.Error(err.Error())
}
bodyString := string(bodyBytes)
slog.Info(bodyString)

// Parse response
var response ChatResponse
return &response, json.NewDecoder(resp.Body).Decode(&response)
}

func IsAzure(apiType APIType) bool {
return apiType == APITypeAzure || apiType == APITypeAzureAD
}


func (c *Client) buildURL(suffix string, model string) string {
if IsAzure(c.apiType) {
return c.buildAzureURL(suffix, model)
}

slog.Info("request url: " + fmt.Sprintf("%s%s", c.baseURL, suffix))

// open ai implement:
return fmt.Sprintf("%s%s", c.baseURL, suffix)
}

func (c *Client) buildAzureURL(suffix string, model string) string {
baseURL := c.baseURL
baseURL = strings.TrimRight(baseURL, "/")

// azure example url:
// /openai/deployments/{model}/chat/completions?api-version={api_version}
return fmt.Sprintf("%s/openai/deployments/%s%s?api-version=%s",
baseURL, model, suffix, c.apiVersion,
)
}
38 changes: 38 additions & 0 deletions pkg/providers/openai/openai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package openai

import (
"testing"
"encoding/json"
"fmt"
)



func TestOpenAIClient(t *testing.T) {
// Initialize the OpenAI client

poolName := "default"
modelName := "gpt-3.5-turbo"

payload := map[string]interface{}{
"message": []map[string]string{
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "tell me a joke",
},
},
"messageHistory": []string{"Hello there", "How are you?", "I'm good, how about you?"},
}

payloadBytes, _ := json.Marshal(payload)

c := &Client{}

resp, err := c.Run(poolName, modelName, payloadBytes)

fmt.Println(resp, err)
}
Loading

0 comments on commit db38f44

Please sign in to comment.