Skip to content

Commit

Permalink
#29: update request from defaultParams
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Dec 20, 2023
1 parent 388f01c commit f38473e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 56 deletions.
52 changes: 27 additions & 25 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"io"
"log/slog"
"net/http"
"reflect"
"strings"
"reflect"
)

const (
Expand All @@ -20,10 +20,10 @@ const (
type ChatRequest struct {
Model string `json:"model" validate:"required,lowercase"`
Messages []*ChatMessage `json:"messages" validate:"required"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"`
MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"`
N int `json:"n,omitempty" validate:"omitempty,gte=1"`
Temperature float64 `json:"temperature" validate:"omitempty,gte=0,lte=1"`
TopP float64 `json:"top_p" validate:"omitempty,gte=0,lte=1"`
MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"`
N int `json:"n" validate:"omitempty,gte=1"`
StopWords []string `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"`
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
Expand Down Expand Up @@ -72,6 +72,8 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest {
return nil
}

slog.Info("creating chatRequest from payload")

var messages []*ChatMessage
for _, msg := range requestBody.Message {
chatMsg := &ChatMessage{
Expand All @@ -87,7 +89,7 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest {
// iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value

chatRequest := &ChatRequest{
Model: c.Provider.Model,
Model: c.setModel(),
Messages: messages,
Temperature: 0.8,
TopP: 1,
Expand All @@ -107,21 +109,20 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest {

// Use reflection to dynamically assign default parameter values
defaultParams := c.Provider.DefaultParams
v := reflect.ValueOf(chatRequest).Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
fieldName := field.Name
defaultValue, ok := defaultParams[fieldName]
if ok && defaultValue != nil {
fieldValue := v.FieldByName(fieldName)
if fieldValue.IsValid() && fieldValue.CanSet() {
fieldValue.Set(reflect.ValueOf(defaultValue))
}

chatRequestValue := reflect.ValueOf(chatRequest).Elem()
chatRequestType := chatRequestValue.Type()

for i := 0; i < chatRequestValue.NumField(); i++ {
jsonTag := chatRequestType.Field(i).Tag.Get("json")
fmt.Println(jsonTag)
if value, ok := defaultParams[jsonTag]; ok {
fieldValue := chatRequestValue.Field(i)
fieldValue.Set(reflect.ValueOf(value))
}
}

fmt.Println(chatRequest)
fmt.Println(chatRequest, defaultParams)

return chatRequest
}
Expand Down Expand Up @@ -158,13 +159,6 @@ type StreamedChatResponsePayload struct {

// CreateChatResponse creates chat Response.
func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) {
if r.Model == "" {
if c.Provider.Model == "" {
r.Model = defaultChatModel
} else {
r.Model = c.Provider.Model
}
}

_ = ctx // keep this for future use

Expand Down Expand Up @@ -309,3 +303,11 @@ func (c *Client) buildAzureURL(suffix string, model string) string {
baseURL, model, suffix, c.apiVersion,
)
}

func (c *Client) setModel() string {
if c.Provider.Model == "" {
return defaultChatModel
} else {
return c.Provider.Model
}
}
4 changes: 2 additions & 2 deletions pkg/providers/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ func TestOpenAIClient(t *testing.T) {

payloadBytes, _ := json.Marshal(payload)

c := &Client{}
c, _ := OpenAiClient(poolName, modelName, payloadBytes)

resp, _ := c.Run(poolName, modelName, payloadBytes)
resp, _ := c.Chat()

respJSON, _ := json.Marshal(resp)

Expand Down
76 changes: 47 additions & 29 deletions pkg/providers/openai/openaiclient.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// TODO: Explore resource pooling
// TODO: Optimize Type use
// TODO: Explore Hertz TLS & resource pooling

// OpenAI package provide a set of functions to interact with the OpenAI API.
package openai

import (
Expand Down Expand Up @@ -46,7 +46,9 @@ const (
// Client is a client for the OpenAI API.
type Client struct {
Provider providers.Provider
PoolName string
baseURL string
payload []byte
organization string
apiType APIType
httpClient *client.Client
Expand All @@ -55,34 +57,16 @@ type Client struct {
apiVersion string
}

func (c *Client) Run(poolName string, modelName string, payload []byte) (*ChatResponse, error) {
c, err := c.NewClient(poolName, modelName)
if err != nil {
slog.Error("Error:" + err.Error())
return nil, err
}

// Create a new chat request

slog.Info("creating chat request")

chatRequest := c.CreateChatRequest(payload)

slog.Info("chat request created")

// Send the chat request

slog.Info("sending chat request")

resp, err := c.CreateChatResponse(context.Background(), chatRequest)

return resp, err
}

func (c *Client) NewClient(poolName string, modelName string) (*Client, error) {
// Returns a []*Client of OpenAI
// modelName is determined by the model pool
// poolName is determined by the route the request came from
// OpenAiClient creates a new client for the OpenAI API.
//
// Parameters:
// - poolName: The name of the pool to connect to.
// - modelName: The name of the model to use.
//
// Returns:
// - *Client: A pointer to the created client.
// - error: An error if the client creation failed.
func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) {

providerName := "openai"

Expand Down Expand Up @@ -139,6 +123,9 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) {
// Create clients for each OpenAI provider
client := &Client{
Provider: *selectedProvider,
PoolName: poolName,
baseURL: defaultBaseURL,
payload: payload,
organization: defaultOrganization,
apiType: APITypeOpenAI,
httpClient: HTTPClient(),
Expand All @@ -147,6 +134,37 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) {
return client, nil
}

// Chat sends a chat request to the specified OpenAI model.
//
// Parameters:
// - payload: The user payload for the chat request.
// Returns:
// - *ChatResponse: a pointer to a ChatResponse
// - error: An error if the request failed.
func (c *Client) Chat() (*ChatResponse, error) {

// Create a new chat request

slog.Info("creating chat request")

chatRequest := c.CreateChatRequest(c.payload)

slog.Info("chat request created")

// Send the chat request

slog.Info("sending chat request")

resp, err := c.CreateChatResponse(context.Background(), chatRequest)

return resp, err
}

// HTTPClient returns a new Hertz HTTP client.
//
// It creates a new client using the client.NewClient() function and returns the client.
// If an error occurs during the creation of the client, it logs the error using slog.Error().
// The function returns the created client or nil if an error occurred.
func HTTPClient() *client.Client {
c, err := client.NewClient()
if err != nil {
Expand Down

0 comments on commit f38473e

Please sign in to comment.