Skip to content

Commit

Permalink
#29: Refactor OpenAI client and chat request creation
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Dec 26, 2023
1 parent e8fef9f commit cbcfaf3
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 77 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ glide
tmp
coverage.txt
precommit.txt
openai_test.go
openai_test.go
pkg/providers/openai/openai_test.go
38 changes: 14 additions & 24 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,15 @@ const (

// Client is a client for the OpenAI API.
type ProviderClient struct {
Provider providers.Provider `validate:"required"`
PoolName string `validate:"required"`
BaseURL string `validate:"required"`
Payload []byte `validate:"required"`
UnifiedData providers.UnifiedAPIData `validate:"required"`
HTTPClient *http.Client `validate:"required"`
}

// ChatRequest is a request to complete a chat completion..
type ChatRequest struct {
Model string `json:"model" validate:"required,lowercase"`
Messages []*ChatMessage `json:"messages" validate:"required"`
Messages []string `json:"messages" validate:"required"`
Temperature float64 `json:"temperature,omitempty" validate:"omitempty,gte=0,lte=1"`
TopP float64 `json:"top_p,omitempty" validate:"omitempty,gte=0,lte=1"`
MaxTokens int `json:"max_tokens,omitempty" validate:"omitempty,gte=0"`
Expand Down Expand Up @@ -104,7 +102,7 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) {

slog.Info("creating chat request")

chatRequest := c.CreateChatRequest(c.Payload)
chatRequest := c.CreateChatRequest(c.UnifiedData)

slog.Info("chat request created")

Expand All @@ -117,28 +115,20 @@ func (c *ProviderClient) Chat() (*ChatResponse, error) {
return resp, err
}

func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest {
var requestBody providers.RequestBody
err := json.Unmarshal(message, &requestBody)
if err != nil {
slog.Error("Error:", err)
return nil
}
func (c *ProviderClient) CreateChatRequest(unifiedData providers.UnifiedAPIData) *ChatRequest {

slog.Info("creating chatRequest from payload")

var messages []*ChatMessage
for _, msg := range requestBody.Message {
chatMsg := &ChatMessage{
Role: msg.Role,
Content: msg.Content,
}
if msg.Role == "user" {
chatMsg.Content += " " + strings.Join(requestBody.MessageHistory, " ")
}
messages = append(messages, chatMsg)
var messages []string

// Add items from messageHistory first
for _, history := range unifiedData.MessageHistory {
messages = append(messages, history)
}

// Add msg variable last
messages = append(messages, unifiedData.Message)

// iterate through self.Provider.DefaultParams and add them to the request otherwise leave the default value

chatRequest := &ChatRequest{
Expand All @@ -161,7 +151,7 @@ func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest {
}

// Use reflection to dynamically assign default parameter values
defaultParams := c.Provider.DefaultParams
defaultParams := unifiedData.Params

chatRequestValue := reflect.ValueOf(chatRequest).Elem()
chatRequestType := chatRequestValue.Type()
Expand Down Expand Up @@ -219,7 +209,7 @@ func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, er

fmt.Println("ReqBody" + reqBody.String())

req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey)
req.Header.Set("Authorization", "Bearer "+c.UnifiedData.APIKey)
req.Header.Set("Content-Type", "application/json")

resp, err := c.HTTPClient.Do(req)
Expand Down
17 changes: 0 additions & 17 deletions pkg/providers/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,7 @@ func TestOpenAIClient(t *testing.T) {

_ = t

poolName := "default"
modelName := "gpt-3.5-turbo"

payload := map[string]interface{}{
"message": []map[string]string{
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "tell me a joke",
},
},
"messageHistory": []string{"Hello there", "How are you?", "I'm good, how about you?"},
}

payloadBytes, _ := json.Marshal(payload)

c, err := Client(poolName, modelName, payloadBytes)
if err != nil {
Expand Down
39 changes: 4 additions & 35 deletions pkg/providers/openai/openaiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ const (
providerName = "openai"
providerVarPath = "/Users/max/code/Glide/pkg/providers/providerVars.yaml"
configPath = "/Users/max/code/Glide/config.yaml"
defaultBaseURL = "https://api.openai.com/v1"
)

// ErrEmptyResponse is returned when the OpenAI API returns an empty response.
var (
ErrEmptyResponse = errors.New("empty response")
)


// OpenAiClient creates a new client for the OpenAI API.
//
// Parameters:
Expand All @@ -33,48 +35,15 @@ var (
// Returns:
// - *Client: A pointer to the created client.
// - error: An error if the client creation failed.
func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) {
provVars, err := providers.ReadProviderVars(providerVarPath)
if err != nil {
return nil, fmt.Errorf("failed to read provider vars: %w", err)
}

defaultBaseURL, err := providers.GetDefaultBaseURL(provVars, providerName)
if err != nil {
return nil, fmt.Errorf("failed to get default base URL: %w", err)
}
func Client(UnifiedData providers.UnifiedAPIData) (*ProviderClient, error) {

config, err := providers.ReadConfig(configPath) // TODO: replace with struct built in router/pool
if err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
}

// Find the pool with the specified name from global config. This may not be necessary if details are passed directly in struct
selectedPool, err := providers.FindPoolByName(config.Gateway.Pools, poolName)
if err != nil {
return nil, fmt.Errorf("failed to find pool: %w", err)
}

// Find the OpenAI provider params in the selected pool with the specified model. This may not be necessary if details are passed directly in struct
selectedProvider, err := providers.FindProviderByModel(selectedPool.Providers, providerName, modelName)
if err != nil {
return nil, fmt.Errorf("provider error: %w", err)
}

// Create a new client
c := &ProviderClient{
Provider: *selectedProvider,
PoolName: poolName,
BaseURL: defaultBaseURL,
Payload: payload,
UnifiedData: UnifiedData,
HTTPClient: providers.HTTPClient,
}

v := validator.New()
err = v.Struct(c)
if err != nil {
return nil, fmt.Errorf("failed to validate client: %w", err)
}

return c, nil
}
10 changes: 10 additions & 0 deletions pkg/providers/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ var HTTPClient = &http.Client{
},
}

type UnifiedAPIData struct {
Provider string `json:"provider"`
Model string `json:"model"`
APIKey string `json:"api_key"`
Params map[string]interface{} `json:"params"`
Message string `json:"message"`
MessageHistory []string `json:"messageHistory"`
}


// Helper Functions

func FindPoolByName(pools []Pool, name string) (*Pool, error) {
Expand Down

0 comments on commit cbcfaf3

Please sign in to comment.