From f38473e6b3753ef2a84bb68c02131296f1404c7b Mon Sep 17 00:00:00 2001 From: Max Date: Wed, 20 Dec 2023 07:36:14 -0700 Subject: [PATCH] #29: update request from defaultParams --- pkg/providers/openai/chat.go | 52 ++++++++++--------- pkg/providers/openai/openai_test.go | 4 +- pkg/providers/openai/openaiclient.go | 76 +++++++++++++++++----------- 3 files changed, 76 insertions(+), 56 deletions(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index f5be4c0b..40ec17f4 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -8,8 +8,8 @@ import ( "io" "log/slog" "net/http" - "reflect" "strings" + "reflect" ) const ( @@ -20,10 +20,10 @@ const ( type ChatRequest struct { Model string `json:"model" validate:"required,lowercase"` Messages []*ChatMessage `json:"messages" validate:"required"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"` - MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"` - N int `json:"n,omitempty" validate:"omitempty,gte=1"` + Temperature float64 `json:"temperature" validate:"omitempty,gte=0,lte=1"` + TopP float64 `json:"top_p" validate:"omitempty,gte=0,lte=1"` + MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` + N int `json:"n" validate:"omitempty,gte=1"` StopWords []string `json:"stop,omitempty"` Stream bool `json:"stream,omitempty" validate:"omitempty, boolean"` FrequencyPenalty int `json:"frequency_penalty,omitempty"` @@ -72,6 +72,8 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { return nil } + slog.Info("creating chatRequest from payload") + var messages []*ChatMessage for _, msg := range requestBody.Message { chatMsg := &ChatMessage{ @@ -87,7 +89,7 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { // iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value chatRequest := &ChatRequest{ - Model: c.Provider.Model, + Model: c.setModel(), Messages: messages, Temperature: 0.8, TopP: 1, @@ -107,21 +109,20 @@ func (c *Client) CreateChatRequest(message []byte) *ChatRequest { // Use reflection to dynamically assign default parameter values defaultParams := c.Provider.DefaultParams - v := reflect.ValueOf(chatRequest).Elem() - t := v.Type() - for i := 0; i < v.NumField(); i++ { - field := t.Field(i) - fieldName := field.Name - defaultValue, ok := defaultParams[fieldName] - if ok && defaultValue != nil { - fieldValue := v.FieldByName(fieldName) - if fieldValue.IsValid() && fieldValue.CanSet() { - fieldValue.Set(reflect.ValueOf(defaultValue)) - } + + chatRequestValue := reflect.ValueOf(chatRequest).Elem() + chatRequestType := chatRequestValue.Type() + + for i := 0; i < chatRequestValue.NumField(); i++ { + jsonTag := chatRequestType.Field(i).Tag.Get("json") + fmt.Println(jsonTag) + if value, ok := defaultParams[jsonTag]; ok { + fieldValue := chatRequestValue.Field(i) + fieldValue.Set(reflect.ValueOf(value)) } } - fmt.Println(chatRequest) + fmt.Println(chatRequest, defaultParams) return chatRequest } @@ -158,13 +159,6 @@ type StreamedChatResponsePayload struct { // CreateChatResponse creates chat Response. func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) { - if r.Model == "" { - if c.Provider.Model == "" { - r.Model = defaultChatModel - } else { - r.Model = c.Provider.Model - } - } _ = ctx // keep this for future use @@ -309,3 +303,11 @@ func (c *Client) buildAzureURL(suffix string, model string) string { baseURL, model, suffix, c.apiVersion, ) } + +func (c *Client) setModel() string { + if c.Provider.Model == "" { + return defaultChatModel + } else { + return c.Provider.Model + } +} diff --git a/pkg/providers/openai/openai_test.go b/pkg/providers/openai/openai_test.go index f070b28f..93709a9b 100644 --- a/pkg/providers/openai/openai_test.go +++ b/pkg/providers/openai/openai_test.go @@ -30,9 +30,9 @@ func TestOpenAIClient(t *testing.T) { payloadBytes, _ := json.Marshal(payload) - c := &Client{} + c, _ := OpenAiClient(poolName, modelName, payloadBytes) - resp, _ := c.Run(poolName, modelName, payloadBytes) + resp, _ := c.Chat() respJSON, _ := json.Marshal(resp) diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go index 8b85b8d4..bced1959 100644 --- a/pkg/providers/openai/openaiclient.go +++ b/pkg/providers/openai/openaiclient.go @@ -1,7 +1,7 @@ // TODO: Explore resource pooling // TODO: Optimize Type use // TODO: Explore Hertz TLS & resource pooling - +// OpenAI package provide a set of functions to interact with the OpenAI API. package openai import ( @@ -46,7 +46,9 @@ const ( // Client is a client for the OpenAI API. type Client struct { Provider providers.Provider + PoolName string baseURL string + payload []byte organization string apiType APIType httpClient *client.Client @@ -55,34 +57,16 @@ type Client struct { apiVersion string } -func (c *Client) Run(poolName string, modelName string, payload []byte) (*ChatResponse, error) { - c, err := c.NewClient(poolName, modelName) - if err != nil { - slog.Error("Error:" + err.Error()) - return nil, err - } - - // Create a new chat request - - slog.Info("creating chat request") - - chatRequest := c.CreateChatRequest(payload) - - slog.Info("chat request created") - - // Send the chat request - - slog.Info("sending chat request") - - resp, err := c.CreateChatResponse(context.Background(), chatRequest) - - return resp, err -} - -func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { - // Returns a []*Client of OpenAI - // modelName is determined by the model pool - // poolName is determined by the route the request came from +// OpenAiClient creates a new client for the OpenAI API. +// +// Parameters: +// - poolName: The name of the pool to connect to. +// - modelName: The name of the model to use. +// +// Returns: +// - *Client: A pointer to the created client. +// - error: An error if the client creation failed. +func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) { providerName := "openai" @@ -139,6 +123,9 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { // Create clients for each OpenAI provider client := &Client{ Provider: *selectedProvider, + PoolName: poolName, + baseURL: defaultBaseURL, + payload: payload, organization: defaultOrganization, apiType: APITypeOpenAI, httpClient: HTTPClient(), @@ -147,6 +134,37 @@ func (c *Client) NewClient(poolName string, modelName string) (*Client, error) { return client, nil } +// Chat sends a chat request to the specified OpenAI model. +// +// Parameters: +// - payload: The user payload for the chat request. +// Returns: +// - *ChatResponse: a pointer to a ChatResponse +// - error: An error if the request failed. +func (c *Client) Chat() (*ChatResponse, error) { + + // Create a new chat request + + slog.Info("creating chat request") + + chatRequest := c.CreateChatRequest(c.payload) + + slog.Info("chat request created") + + // Send the chat request + + slog.Info("sending chat request") + + resp, err := c.CreateChatResponse(context.Background(), chatRequest) + + return resp, err +} + +// HTTPClient returns a new Hertz HTTP client. +// +// It creates a new client using the client.NewClient() function and returns the client. +// If an error occurs during the creation of the client, it logs the error using slog.Error(). +// The function returns the created client or nil if an error occurred. func HTTPClient() *client.Client { c, err := client.NewClient() if err != nil {